Metadata-Version: 2.1
Name: tfimm
Version: 0.1.0
Summary: TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights
Home-page: https://github.com/martinsbruveris/tensorflow-image-models
License: Apache-2.0
Keywords: tensorflow,pretrained models,visual transformer
Author: Martins Bruveris
Author-email: martins.bruveris@gmail.com
Requires-Python: >=3.7,<3.10
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Provides-Extra: timm
Provides-Extra: train
Requires-Dist: markdown (==3.3.4)
Requires-Dist: numpy
Requires-Dist: pyyaml; extra == "timm"
Requires-Dist: tensorflow (>=2.5,<3.0)
Requires-Dist: tensorflow-datasets; extra == "train"
Requires-Dist: timm; extra == "timm"
Requires-Dist: tqdm; extra == "train"
Project-URL: Repository, https://github.com/martinsbruveris/tensorflow-image-models
Description-Content-Type: text/markdown

# TensorFlow-Image-Models

- [Introduction](#introduction)
- [Usage](#usage)
- [Models](#models)
- [License](#license)

## Introduction

TensorFlow-Image-Models (`tfimm`) is a collection of image models with pretrained
weights, obtained by porting architectures from 
[timm](https://github.com/rwightman/pytorch-image-models) to TensorFlow. The hope is
that the number of available architectures will grow over time. For now it contains
vision transformers (ViT and DeiT) and ResNets.

This work would not have been possible without Ross Wightman's `timm` library and the
work on PyTorch/TensorFlow interoperability in HuggingFace's `transformers` repository.
I tried to make sure all source material is acknowledged. Please let me know if I have
missed something.

## Usage

### Installation 

The package can be installed via `pip`,

```shell
pip install tfimm
```

To load pretrained weights, `timm` needs to be installed. This is an optional 
dependency and can be installed via

```shell
pip install tfimm[timm]
```

### Creating models

To load pretrained models use

```python
import tfimm

model = tfimm.create_model("vit_tiny_patch16_224", pretrained="timm")
```

We can list available models with pretrained weights via

```python
import tfimm

print(tfimm.list_models(pretrained="timm"))
```

Most models are pretrained on ImageNet or ImageNet-21k. If we want to use them for other
tasks we need to change the number of classes in the classifier or remove the 
classifier altogether. We can do this by setting the `nb_classes` parameter in 
`create_model`. If `nb_classes=0`, the model will have no classification layer. If
`nb_classes` is set to a value different from the default model config, the 
classification layer will be randomly initialized, while all other weights will be
copied from the pretrained model.

### Saving and loading models

All models are subclassed from `tf.keras.Model` (they are _not_ functional models).
They can still be saved and loaded using the `SavedModel` format.

```
>>> import tensorflow as tf
>>> import tfimm
>>> model = tfimm.create_model("vit_tiny_patch16_224")
>>> type(model)
<class 'tfimm.architectures.vit.ViT'>
>>> model.save("/tmp/my_model")
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'tfimm.architectures.vit.ViT'>
```

For this to work, the `tfimm` library needs to be imported before the model is loaded,
since during the import process, `tfimm` is registering custom models with Keras.
Otherwise, we obtain the following output

```
>>> import tensorflow as tf
>>> loaded_model = tf.keras.models.load_model("/tmp/my_model")
>>> type(loaded_model)
<class 'keras.saving.saved_model.load.Custom>ViT'>
```

## Models

The following architectures are currently available:

- DeiT (vision transformer) 
  [\[github\]](https://github.com/facebookresearch/deit)
  - Training data-efficient image transformers & distillation through attention. 
    [\[arXiv:2012.12877\]](https://arxiv.org/abs/2012.12877) 
- ViT (vision transformer) 
  [\[github\]](https://github.com/google-research/vision_transformer)
  - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
    [\[arXiv:2010.11929\]](https://arxiv.org/abs/2010.11929)
  - How to train your ViT? Data, Augmentation, and Regularization in Vision 
    Transformers. [\[arXiv:2106.10270\]](https://arxiv.org/abs/2106.10270)
  - Includes models trained with the SAM optimizer: Sharpness-Aware Minimization for 
    Efficiently Improving Generalization. 
    [\[arXiv:2010.01412\]](https://arxiv.org/abs/2010.01412)
  - Includes models from: ImageNet-21K Pretraining for the Masses
    [\[arXiv:2104.10972\]](https://arxiv.org/abs/2104.10972) 
    [\[github\]](https://github.com/Alibaba-MIIL/ImageNet21K)
- ResNet (work in progress, most available weights are from `timm`)
  - Deep Residual Learning for Image Recognition. 
    [\[arXiv:1512.03385\]](https://arxiv.org/abs/1512.03385)

## License

This repository is released under the Apache 2.0 license as found in the 
[LICENSE](LICENSE) file.
