from ..torch_core import *
from ..basic_data import DataBunch
from ..callback import *
from ..basic_train import Learner,LearnerCallback
from torch.utils.data.sampler import WeightedRandomSampler
__all__ = ['OverSamplingCallback']
class OverSamplingCallback(LearnerCallback):
def __init__(self,learn:Learner,weights:torch.Tensor=None):
super().__init__(learn)
self.weights = weights
def on_train_begin(self, **kwargs):
ds,dl = self.data.train_ds,self.data.train_dl
self.labels = ds.y.items
assert np.issubdtype(self.labels.dtype, np.integer), "Can only oversample integer values"
_,self.label_counts = np.unique(self.labels,return_counts=True)
if self.weights is None: self.weights = torch.DoubleTensor((1/self.label_counts)[self.labels])
self.total_len_oversample = int(self.data.c*np.max(self.label_counts))
sampler = WeightedRandomSampler(self.weights, self.total_len_oversample)
self.data.train_dl = dl.new(shuffle=False, sampler=sampler)
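#---Example usage (a hedged sketch: `data`, model choice and epoch count are illustrative)---
# With no explicit `weights`, every sample is weighted by 1/count of its class, so the
# WeightedRandomSampler draws minority-class examples more often during training.
# learn = cnn_learner(data, models.resnet18, metrics=accuracy)
# learn.fit_one_cycle(3, callbacks=[OverSamplingCallback(learn)])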
from .torch_core import *
from .basic_train import Learner,LearnerCallback
from torch.nn.parallel import DistributedDataParallel, DataParallel
from torch.utils.data.distributed import DistributedSampler
from fastai.text import TextLMDataBunch
__all__ = ['DistributedRecorder', 'DistributedTrainer', 'read_metrics', 'setup_distrib']
def rnn_reset(self):
if hasattr(self.module, 'reset'): self.module.reset()
DistributedDataParallel.reset = rnn_reset
class ParallelTrainer(LearnerCallback):
_order = -20
def on_train_begin(self, **kwargs): self.learn.model = DataParallel(self.learn.model)
def on_train_end (self, **kwargs): self.learn.model = self.learn.model.module
class DistributedTrainer(LearnerCallback):
_order = -20 # Needs to run before the recorder
def __init__(self, learn:Learner, cuda_id:int=0):
super().__init__(learn)
self.cuda_id,self.train_sampler = cuda_id,None
def _change_dl(self, dl, shuffle):
old_dl = dl
sampler = OurDistributedSampler(dl.dataset, shuffle=shuffle)
new_dl = dl.new(shuffle=False, sampler=sampler)
return old_dl,new_dl,sampler
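#---Example usage (a hedged sketch; script name, GPU count and epoch count are illustrative)---
# Launch with: python -m torch.distributed.launch --nproc_per_node=2 train.py
# parser = argparse.ArgumentParser(); parser.add_argument('--local_rank', type=int, default=0)
# args = parser.parse_args()
# setup_distrib(args.local_rank)           # init the process group and select the CUDA device
# learn = ...                              # build the Learner as usual
# learn.to_distributed(args.local_rank)    # attaches DistributedTrainer and DistributedRecorder
# learn.fit_one_cycle(4)
# For single-process multi-GPU training, learn.to_parallel() attaches ParallelTrainer instead.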
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner
class GANSaveCallback(LearnerCallback):
"""A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."""
def __init__(
self,
learn: GANLearner,
learn_gen: Learner,
filename: str,
save_iters: int = 1000,
):
super().__init__(learn)
self.learn_gen = learn_gen
self.filename = filename
self.save_iters = save_iters
def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        if iteration == 0: return
        # Save the generator Learner every `save_iters` iterations.
        if iteration % self.save_iters == 0:
            self.learn_gen.save(f'{self.filename}_{epoch}_{iteration}')
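#---Example usage (a hedged sketch; `learn` is a GANLearner and `learn_gen` its generator Learner)---
# learn.callbacks.append(GANSaveCallback(learn, learn_gen, filename='gen_checkpoint', save_iters=1000))
# learn.fit(10, lr=1e-4)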
from ..basic_train import Learner, LearnerCallback
__all__ = ['GeneralScheduler', 'TrainingPhase']
@dataclass
class TrainingPhase():
"Schedule hyper-parameters for a phase of `length` iterations."
length:int
def __post_init__(self): self.scheds = dict()
def schedule_hp(self, name, vals, anneal=None):
"Adds a schedule for `name` between `vals` using `anneal`."
self.scheds[name] = Scheduler(vals, self.length, anneal)
return self
class GeneralScheduler(LearnerCallback):
"Schedule multiple `TrainingPhase` for a `Learner`."
def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
super().__init__(learn)
self.phases,self.start_epoch = phases,start_epoch
def on_train_begin(self, epoch:int, **kwargs:Any)->None:
"Initialize the schedulers for training."
res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
self.start_epoch = ifnone(self.start_epoch, epoch)
self.scheds = [p.scheds for p in self.phases]
self.opt = self.learn.opt
for k,v in self.scheds[0].items():
v.restart()
self.opt.set_stat(k, v.start)
self.idx_s = 0
return res
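#---Example usage (a hedged sketch; iteration counts and learning rates are illustrative)---
# Warm the learning rate up over 200 iterations, then anneal it back down over 800,
# using fastai's `annealing_cos` for both phases.
# phases = [TrainingPhase(200).schedule_hp('lr', (1e-5, 1e-3), anneal=annealing_cos),
#           TrainingPhase(800).schedule_hp('lr', (1e-3, 1e-5), anneal=annealing_cos)]
# learn.callbacks.append(GeneralScheduler(learn, phases))
# learn.fit(1)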
"Implements [mixup](https://arxiv.org/abs/1710.09412) training method"
from ..torch_core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
__all__ = ["MixUpCallback", "MixUpLoss"]
class MixUpCallback(LearnerCallback):
"Callback that creates the mixed-up input and target."
def __init__(self, learn:Learner, alpha:float=0.4, stack_x:bool=False, stack_y:bool=True):
super().__init__(learn)
self.alpha,self.stack_x,self.stack_y = alpha,stack_x,stack_y
def on_train_begin(self, **kwargs):
if self.stack_y: self.learn.loss_func = MixUpLoss(self.learn.loss_func)
def on_batch_begin(self, last_input, last_target, train, **kwargs):
"Applies mixup to `last_input` and `last_target` if `train`."
if not train: return
lambd = np.random.beta(self.alpha, self.alpha, last_target.size(0))
lambd = np.concatenate([lambd[:,None], 1-lambd[:,None]], 1).max(1)
lambd = last_input.new(lambd)
shuffle = torch.randperm(last_target.size(0)).to(last_input.device)
x1, y1 = last_input[shuffle], last_target[shuffle]
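#---Example usage (a hedged sketch; fastai also exposes this via Learner.mixup())---
# Each input becomes lambd*x + (1-lambd)*x[shuffle]; lambd ~ Beta(alpha, alpha) is folded
# into [0.5, 1] by the max() above so the original sample always dominates the mix.
# learn = cnn_learner(data, models.resnet34, metrics=accuracy)
# learn.callback_fns.append(partial(MixUpCallback, alpha=0.4))
# learn.fit_one_cycle(4)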
def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule):
super().__init__()
self.loss_funcG,self.loss_funcC,self.gan_model = loss_funcG,loss_funcC,gan_model
def generator(self, output, target):
"Evaluate the `output` with the critic then uses `self.loss_funcG` to combine it with `target`."
fake_pred = self.gan_model.critic(output)
return self.loss_funcG(fake_pred, target, output)
def critic(self, real_pred, input):
"Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcD`."
fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
fake_pred = self.gan_model.critic(fake)
return self.loss_funcC(real_pred, fake_pred)
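#---Example `loss_funcG` (a hedged sketch; the -mean() adversarial term and the 50. pixel weight are illustrative)---
# A generator loss receives the critic's prediction on the fake, the target image and the
# generator output, and can mix an adversarial term with an L1 pixel term:
# def example_gen_loss(fake_pred, target, output):
#     return -fake_pred.mean() + 50. * F.l1_loss(output, target)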
class GANTrainer(LearnerCallback):
"Handles GAN Training."
_order=-20
def __init__(self, learn:Learner, switch_eval:bool=False, clip:float=None, beta:float=0.98, gen_first:bool=False,
show_img:bool=True):
super().__init__(learn)
self.switch_eval,self.clip,self.beta,self.gen_first,self.show_img = switch_eval,clip,beta,gen_first,show_img
self.generator,self.critic = self.model.generator,self.model.critic
def _set_trainable(self):
train_model = self.generator if self.gen_mode else self.critic
loss_model = self.generator if not self.gen_mode else self.critic
requires_grad(train_model, True)
requires_grad(loss_model, False)
if self.switch_eval:
train_model.train()
loss_model.eval()
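#---Example usage (a hedged sketch; sizes, clip value and learning rate are illustrative)---
# generator = basic_generator(in_size=64, n_channels=3)
# critic    = basic_critic(in_size=64, n_channels=3)
# learn = GANLearner.wgan(data, generator, critic, clip=0.01)   # wires a GANTrainer in for you
# learn.fit(10, 2e-4)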
from queue import Queue
import statistics
import torchvision.utils as vutils
from abc import ABC
# tensorboardX is an optional dependency in fastai and must be installed separately.
try: from tensorboardX import SummaryWriter
except ImportError: print("To use these callbacks, run 'pip install tensorboardX', and launch TensorBoard to view the results.")
__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
#---Example usage (applies to any of the callbacks)---
# proj_id = 'Colorize'
# tboard_path = Path('data/tensorboard/' + proj_id)
# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))
class LearnerTensorboardWriter(LearnerCallback):
"Broadly useful callback for Learners that writes to Tensorboard. Writes model histograms, losses/metrics, and gradient stats."
def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
super().__init__(learn=learn)
self.base_dir,self.name,self.loss_iters,self.hist_iters,self.stats_iters = base_dir,name,loss_iters,hist_iters,stats_iters
log_dir = base_dir/name
self.tbwriter = SummaryWriter(str(log_dir))
self.hist_writer = HistogramTBWriter()
self.stats_writer = ModelStatsTBWriter()
self.graph_writer = GraphTBWriter()
self.data = None
self.metrics_root = '/metrics/'
self._update_batches_if_needed()
def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:
"Retrieves new batch of DatasetType, and detaches it."
return self.learn.data.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)