Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def predict(self, paths):
    """Run the wrapped learner on the images found at `paths`.

    Returns a pair `(target_probs, labels)`: the column of predicted
    probabilities for class index `self.target_class_idx`, and the name of
    the highest-probability class for each image.
    """
    test_items = ImageList(paths)
    # Attach the images as the learner's test set so `get_preds` can score them.
    # NOTE(review): this mutates `self.learn.data`; presumably any previous test set is replaced — confirm.
    self.learn.data.add_test(test_items)
    probs_t, _ = self.learn.get_preds(DatasetType.Test)
    # assumes get_preds hands back a CPU tensor (`.numpy()` fails on GPU tensors) — confirm
    probs = probs_t.numpy()
    winners = probs.argmax(axis=1)
    labels = [self.learn.data.classes[k] for k in winners]
    return probs[:, self.target_class_idx], labels
return cls(learn_gen.data, learn_gen.model, learn_crit.model, *losses, switcher=switcher, **learn_kwargs)
@classmethod
def wgan(cls, data:DataBunch, generator:nn.Module, critic:nn.Module, switcher:Callback=None, clip:float=0.01, **learn_kwargs):
    "Build a Wasserstein GAN learner from `data`, `generator` and `critic`."
    # The two losses are passed positionally after the models: NoopLoss first, then WassersteinLoss
    # (presumably generator loss, then critic loss — confirm against `cls.__init__`).
    gen_loss = NoopLoss()
    crit_loss = WassersteinLoss()
    # `clip` is forwarded untouched; presumably the critic weight-clipping bound — confirm.
    return cls(data, generator, critic, gen_loss, crit_loss,
               switcher=switcher, clip=clip, **learn_kwargs)
class NoisyItem(ItemBase):
    "A random `ItemBase` of size `noise_sz`."
    def __init__(self, noise_sz):
        # Sample first, then store: `obj` keeps the requested size, while `data`
        # holds a (noise_sz, 1, 1) tensor drawn with `torch.randn`.
        noise = torch.randn(noise_sz, 1, 1)
        self.obj = noise_sz
        self.data = noise

    def __str__(self):
        # Noise has no meaningful text representation.
        return ''

    def apply_tfms(self, tfms, **kwargs):
        # The noise itself is never transformed, but every transform still has `resolve()` called
        # (presumably to keep random parameter draws in step with the paired target — confirm).
        for tfm in listify(tfms):
            tfm.resolve()
        return self
class GANItemList(ImageList):
    "`ItemList` for GANs: every input is a fresh `NoisyItem`, every label is an `ImageList` item."
    _label_cls = ImageList

    def __init__(self, items, noise_sz:int=100, **kwargs):
        super().__init__(items, **kwargs)
        self.noise_sz = noise_sz
        # Listing the attribute in `copy_new` presumably carries it over to lists derived from this one — confirm.
        self.copy_new.append('noise_sz')

    def get(self, i):
        # The index is ignored: each call draws new noise of the configured size.
        return NoisyItem(self.noise_sz)

    def reconstruct(self, t):
        # Only the leading dimension of `t` is used, as the noise size.
        return NoisyItem(t.size(0))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Plot only the target images `ys` (the noise inputs are not drawn) on a figure of `figsize`."
        # Arguments are deliberately swapped so the parent plots `ys` in the input position.
        super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)
def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, seed:int=None,
                   **kwargs):
    "Build from the files `fnames` under `path`, labelling each one with `label_func` and holding out a random `valid_pct`."
    items = ImageList(fnames, path=path)
    split = items.split_by_rand_pct(valid_pct, seed)
    labelled = split.label_from_func(label_func)
    return cls.create_from_ll(labelled, **kwargs)
def __init__(self, ds:ItemList):
    # Remember the class list of the dataset this processor is built from.
    self.classes = ds.classes
def process(self, ds:ItemList):
    # Copy the stored class list onto `ds`, along with its length as `ds.c`.
    classes = self.classes
    n_classes = len(classes)  # computed before any assignment, so a failing len() leaves `ds` untouched
    ds.classes = classes
    ds.c = n_classes
class SegmentationLabelList(ImageList):
    "`ItemList` of segmentation masks, opened with `open_mask`."
    _processor = SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        # Listing 'classes' in `copy_new` presumably carries it over to derived lists — confirm.
        self.copy_new.append('classes')
        loss = CrossEntropyFlat(axis=1)
        self.classes = classes
        self.loss_func = loss

    def open(self, fn):
        return open_mask(fn, after_open=self.after_open)

    def analyze_pred(self, pred, thresh:float=0.5):
        # Argmax over dim 0 (presumably the class axis), then restore a leading singleton dim.
        # `thresh` is accepted for interface compatibility but unused.
        return pred.argmax(dim=0).unsqueeze(0)

    def reconstruct(self, t:Tensor):
        return ImageSegment(t)
class SegmentationItemList(ImageList):
    "`ItemList` for segmentation: image inputs labelled by a `SegmentationLabelList`."
    _label_cls = SegmentationLabelList
    _square_show_res = False
class PointsProcessor(PreProcessor):
    "`PreProcessor` recording how many regression targets each point item carries."
    def __init__(self, ds:ItemList):
        # Count every value of the first item once flattened; presumably all items share this size — confirm.
        flat = ds.items[0].reshape(-1)
        self.c = len(flat)

    def process(self, ds:ItemList):
        # Publish the stored target count on the dataset.
        ds.c = self.c
class PointsLabelList(ItemList):
    "`ItemList` for points."
    _processor = PointsProcessor

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.loss_func = MSELossFlat()

    # A single `get` only: a stray earlier `get` that merely reassigned `loss_func` and
    # returned None was dead code (shadowed by this definition) and has been removed.
    def get(self, i):
        "Return the i-th target as `ImagePoints`, scaled against the size reported by `_get_size(self.x, i)`."
        o = super().get(i)
        # `_get_size(self.x, i)` is presumably the size of the i-th input image — confirm.
        return ImagePoints(FlowField(_get_size(self.x,i), o), scale=True)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Reshape a flat prediction into rows of two values; `thresh` is unused."
        return pred.view(-1,2)

    def reconstruct(self, t, x):
        "Rebuild `ImagePoints` from tensor `t` using `x.size`, with `scale=False` (unlike `get`)."
        return ImagePoints(FlowField(x.size, t), scale=False)
class PointsItemList(ImageList):
    "`ItemList` mapping `Image` inputs to `ImagePoints` targets."
    _label_cls = PointsLabelList
    _square_show_res = False
class ImageImageList(ImageList):
    "`ItemList` suitable for `Image` to `Image` tasks."
    _label_cls,_square_show,_square_show_res = ImageList,False,False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a figure of `figsize`."
        # One row per pair: input in column 0, target in column 1.
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for i, (x,y) in enumerate(zip(xs,ys)):
            x.show(ax=axs[i,0], **kwargs)
            y.show(ax=axs[i,1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`."
        # Column order follows the title: input | prediction | target.
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
            x.show(ax=axs[i,0], **kwargs)
            y.show(ax=axs[i,2], **kwargs)
            # Predictions fill the middle ("Prediction") column; without this call `z` is never drawn.
            z.show(ax=axs[i,1], **kwargs)
def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None, valid_pct:float=0.2,
            seed:int=None, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='', **kwargs:Any)->'ImageDataBunch':
    "Build from `df`: items come from column `fn_col`, labels from `label_col`, with a random `valid_pct` held out."
    items = ImageList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
    split = items.split_by_rand_pct(valid_pct, seed)
    labelled = split.label_from_df(label_delim=label_delim, cols=label_col)
    return cls.create_from_ll(labelled, **kwargs)
def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, seed:int=None,
               item_cls:Callable=None, **kwargs):
    "Build from parallel lists `fnames`/`labels` under `path`, holding out a random `valid_pct`."
    item_cls = ifnone(item_cls, ImageList)
    # Pair each file name with its label; a repeated name keeps its last label.
    label_of = dict(zip(fnames, labels))
    split = item_cls(fnames, path=path).split_by_rand_pct(valid_pct, seed)
    src = split.label_from_func(lambda fn: label_of[fn])
    return cls.create_from_ll(src, **kwargs)
class PointsLabelList(ItemList):
    "`ItemList` of point-regression targets, returned as `ImagePoints`."
    _processor = PointsProcessor

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.loss_func = MSELossFlat()

    def get(self, i):
        # Wrap the raw coordinates in a FlowField built from `_get_size(self.x, i)`
        # (presumably the i-th input image's size — confirm), with scale=True.
        raw = super().get(i)
        field = FlowField(_get_size(self.x, i), raw)
        return ImagePoints(field, scale=True)

    def analyze_pred(self, pred, thresh:float=0.5):
        # Reshape the flat prediction into rows of 2 values; `thresh` is unused.
        return pred.view(-1, 2)

    def reconstruct(self, t, x):
        return ImagePoints(FlowField(x.size, t), scale=False)
class PointsItemList(ImageList):
    "`ItemList` mapping `Image` inputs to `ImagePoints` targets."
    _label_cls = PointsLabelList
    _square_show_res = False
class ImageImageList(ImageList):
    "`ItemList` for tasks whose inputs and targets are both `Image`s."
    _label_cls = ImageList
    _square_show = False
    _square_show_res = False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Plot each input in `xs` beside its target in `ys`, one row per pair, on a figure of `figsize`."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row, pair in enumerate(zip(xs, ys)):
            inp, tgt = pair
            inp.show(ax=axs[row, 0], **kwargs)
            tgt.show(ax=axs[row, 1], **kwargs)
        plt.tight_layout()
def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:TfmList=None, **kwargs):
    "Build an empty `ImageDataBunch` rooted at `path` that knows `classes` (deprecated; was used for inference)."
    deprecation_msg = """This method is deprecated and will be removed in a future version, use `load_learner` after
`Learner.export()`"""
    warn(deprecation_msg, DeprecationWarning)
    # An empty item list with no split, labelled with the constant 0 and the given `classes`.
    empty = ImageList([], path=path, ignore_empty=True).split_none()
    labelled = empty.label_const(0, label_cls=CategoryList, classes=classes)
    return labelled.transform(ds_tfms, **kwargs).databunch()
def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid', test:Optional[PathOrStr]=None,
                valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
    "Build from an imagenet-style tree in `path`: split by the `train`/`valid` folders, or randomly by `valid_pct` if given."
    path = Path(path)
    # `test` is excluded from the labelled items here and forwarded to `create_from_ll` below.
    items = ImageList.from_folder(path, exclude=test)
    if valid_pct is not None:
        split = items.split_by_rand_pct(valid_pct, seed)
    else:
        split = items.split_by_folder(train=train, valid=valid)
    labelled = split.label_from_folder(classes=classes)
    return cls.create_from_ll(labelled, test=test, **kwargs)