Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_model_to_learner():
    """`model_to_learner` should wrap torchvision models into a learner.

    For both a pretrained ResNet-34 and an (untrained) SqueezeNet 1.0, the
    resulting learner must expose 1000 data classes (the ImageNet label set)
    and keep the original network object type as `learner.model`.
    """
    def check(learner, model_cls):
        # ImageNet models ship with a 1000-way classifier.
        assert len(learner.data.classes) == 1000
        assert isinstance(learner.model, model_cls)

    # ImageNet-pretrained ResNet.
    check(model_to_learner(models.resnet34(pretrained=True)), models.ResNet)
    # SqueezeNet, built without pretrained weights.
    check(model_to_learner(models.squeezenet1_0()), models.SqueezeNet)
# Split squeezenet model on maxpool layers
def _squeezenet_split(m:nn.Module): return (m[0][0][5], m[0][0][8], m[1])
def _densenet_split(m:nn.Module): return (m[0][0][7],m[1])
def _vgg_split(m:nn.Module): return (m[0][0][22],m[1])
def _alexnet_split(m:nn.Module): return (m[0][0][6],m[1])
# Per-family metadata templates. Each holds a `cut` value and a `split`
# function (the layer-group splitter for that family). What `cut` means is
# decided by whoever consumes it -- presumably where the pretrained body is
# truncated; confirm against the caller.
_default_meta    = dict(cut=None, split=_default_split)
_resnet_meta     = dict(cut=-2,   split=_resnet_split)
_squeezenet_meta = dict(cut=-1,   split=_squeezenet_split)
_densenet_meta   = dict(cut=-1,   split=_densenet_split)
_vgg_meta        = dict(cut=-1,   split=_vgg_split)
_alexnet_meta    = dict(cut=-1,   split=_alexnet_split)

# Registry: torchvision architecture constructor -> metadata dict. Every
# architecture gets its own shallow copy of its family template, so editing
# one entry does not leak into siblings or into the template itself.
model_meta = {
    arch: dict(family_meta)
    for family_meta, archs in (
        (_resnet_meta, (models.resnet18, models.resnet34, models.resnet50,
                        models.resnet101, models.resnet152)),
        (_squeezenet_meta, (models.squeezenet1_0, models.squeezenet1_1)),
        (_densenet_meta, (models.densenet121, models.densenet169,
                          models.densenet201, models.densenet161)),
        (_vgg_meta, (models.vgg11_bn, models.vgg13_bn,
                     models.vgg16_bn, models.vgg19_bn)),
        (_alexnet_meta, (models.alexnet,)),
    )
    for arch in archs
}
def cnn_config(arch):
    """Get the metadata associated with `arch`.

    Returns the `model_meta` entry registered for `arch`, or the shared
    `_default_meta` dict for unregistered architectures. The dict returned is
    the registry's own object, not a copy, so callers should not mutate it.

    Side effect: sets `torch.backends.cudnn.benchmark = True` on every call.
    """
    torch.backends.cudnn.benchmark = True
    try:
        return model_meta[arch]
    except KeyError:
        return _default_meta
def has_pool_type(m):
# Set parameters based on your selected model.
# In[6]:
# Each MODEL_TYPE selects a torchvision architecture and an IM_SIZE
# (presumably the image side length fed to the model -- confirm downstream).
if MODEL_TYPE == "high_accuracy":
    ARCHITECTURE = models.resnet50
    IM_SIZE = 500
elif MODEL_TYPE == "fast_inference":
    ARCHITECTURE = models.resnet18
    IM_SIZE = 300
elif MODEL_TYPE == "small_size":
    ARCHITECTURE = models.squeezenet1_1
    IM_SIZE = 300
else:
    # Fail loudly instead of leaving ARCHITECTURE/IM_SIZE undefined (or stale
    # from a previous run of this cell) and erroring much later.
    raise ValueError(
        "MODEL_TYPE must be one of 'high_accuracy', 'fast_inference' or "
        "'small_size', got %r" % (MODEL_TYPE,)
    )
# We'll automatically determine if your dataset is a multi-label or traditional (single-label) classification problem. To do so, we'll use the `is_data_multilabel` helper function. In order to detect whether or not a dataset is multi-label, the helper function will check to see if the datapath contains a csv file that has a column 'labels' where the values are space-delimited. You can inspect the function by calling `is_data_multilabel??`.
#
# This function assumes that your multi-label dataset is structured in the recommended format shown in the [multilabel notebook](02_multilabel_classification.ipynb).
# In[7]:
multilabel = is_data_multilabel(DATA_PATH)
# Multi-label datasets are scored with Hamming accuracy; single-label ones
# with plain accuracy.
metric = hamming_accuracy if multilabel else accuracy
# ## Pre-processing <a name="preprocessing"></a>
#
# Make sure that `MODEL_TYPE` is correctly set.
# In[5]:
# Validate with an explicit raise rather than `assert`: asserts are stripped
# under `python -O`, and the message tells the user which values are valid.
if MODEL_TYPE not in ("high_accuracy", "fast_inference", "small_size"):
    raise ValueError(
        "MODEL_TYPE must be one of 'high_accuracy', 'fast_inference' or "
        "'small_size', got %r" % (MODEL_TYPE,)
    )
# Set parameters based on your selected model.
# In[6]:
# Each MODEL_TYPE selects a torchvision architecture and an IM_SIZE
# (presumably the image side length fed to the model -- confirm downstream).
if MODEL_TYPE == "high_accuracy":
    ARCHITECTURE = models.resnet50
    IM_SIZE = 500
elif MODEL_TYPE == "fast_inference":
    ARCHITECTURE = models.resnet18
    IM_SIZE = 300
elif MODEL_TYPE == "small_size":
    ARCHITECTURE = models.squeezenet1_1
    IM_SIZE = 300
else:
    # Fail loudly instead of leaving ARCHITECTURE/IM_SIZE undefined (or stale
    # from a previous run of this cell) and erroring much later.
    raise ValueError(
        "MODEL_TYPE must be one of 'high_accuracy', 'fast_inference' or "
        "'small_size', got %r" % (MODEL_TYPE,)
    )
# We'll automatically determine if your dataset is a multi-label or traditional (single-label) classification problem. To do so, we'll use the `is_data_multilabel` helper function. In order to detect whether or not a dataset is multi-label, the helper function will check to see if the datapath contains a csv file that has a column 'labels' where the values are space-delimited. You can inspect the function by calling `is_data_multilabel??`.
#
# This function assumes that your multi-label dataset is structured in the recommended format shown in the [multilabel notebook](02_multilabel_classification.ipynb).
# In[7]: