deeptrain package¶
Subpackages¶
- deeptrain.backend package
- deeptrain.util package
- Subpackages
- Submodules
- deeptrain.util.algorithms module
- deeptrain.util.configs module
- deeptrain.util._default_configs module
- deeptrain.util.data_loaders module
- deeptrain.util.experimental module
- deeptrain.util.logging module
- deeptrain.util.misc module
- deeptrain.util.preprocessors module
- deeptrain.util.saving module
- deeptrain.util.searching module
- deeptrain.util.training module
- deeptrain.util._traingen_utils module
- Module contents
Submodules¶
deeptrain.callbacks module¶
- deeptrain.callbacks.binary_preds_per_iteration_cb(self)¶
  Suited for binary classification sigmoid outputs. See binary_preds_per_iteration().
- deeptrain.callbacks.binary_preds_distribution_cb(self)¶
  Suited for binary classification sigmoid outputs. See binary_preds_distribution().
- deeptrain.callbacks.infer_train_hist_cb(self)¶
  Suited for binary classification sigmoid outputs. See infer_train_hist().
- deeptrain.callbacks.make_layer_hists_cb(_id='*', mode='weights', x=None, y=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)¶
  Layer histograms grid callback. See layer_hists().
- class deeptrain.callbacks.TraingenCallback¶
  Bases: object
  Required base class for callbacks objects used by TrainGenerator. Enables using TrainGenerator attributes by assigning it to self via init_with_traingen. Methods are called by TrainGenerator at several stages: train, validation, save, load, __init__. stage is an optional argument to every method (except init_with_traingen) to allow finer-grained control, particularly for ('val_end', 'train:epoch'). (A minimal subclass sketch follows the method list below.)
  Methods not implemented by the inheriting class will be skipped by catching NotImplementedError.
  __init__ sets a private attribute _counters, which can be used along a "freq" to call methods every Nth callback, instead of every callback. Example: RandomSeedSetter.
  - init_with_traingen(traingen=None)¶
    Called by TrainGenerator.__init__(), passing in self (the TrainGenerator instance).
  - on_train_iter_end(stage=None)¶
    Called by _on_iter_end, within TrainGenerator._train_postiter_processing(), with stage='train:iter'.
  - on_train_batch_end(stage=None)¶
    Called by _on_batch_end, within TrainGenerator._train_postiter_processing(), with stage='train:batch'.
  - on_train_epoch_end(stage=None)¶
    Called by _on_epoch_end, within TrainGenerator._train_postiter_processing(), with stage='train:epoch'.
  - on_val_iter_end(stage=None)¶
    Called by _on_iter_end, within TrainGenerator._val_postiter_processing(), with stage='val:iter'.
  - on_val_batch_end(stage=None)¶
    Called by _on_batch_end, within TrainGenerator._val_postiter_processing(), with stage='val:batch'.
  - on_val_epoch_end(stage=None)¶
    Called by _on_epoch_end, within TrainGenerator._val_postiter_processing(), with stage='val:epoch'.
  - on_val_end(stage=None)¶
    Called by TrainGenerator._on_val_end(), with:
    - stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
    - stage='val_end' otherwise
  - on_save(stage=None)¶
    Called by TrainGenerator.save() with stage='save'.
  - on_load(stage=None)¶
    Called by TrainGenerator.load() with stage='load'.
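  A minimal subclass sketch, assuming only the interface documented above; the class name, hook choice, and printed message are illustrative:
  >>> from deeptrain.callbacks import TraingenCallback
  >>>
  >>> class EpochNotifier(TraingenCallback):
  ...     """Hypothetical callback: announces each finished train epoch."""
  ...     def init_with_traingen(self, traingen=None):
  ...         # store the TrainGenerator to access its attributes later
  ...         self.tg = traingen
  ...     def on_train_epoch_end(self, stage=None):
  ...         # stage == 'train:epoch' here; hooks left unimplemented are
  ...         # skipped via NotImplementedError
  ...         print("finished epoch", self.tg.epoch)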
- class deeptrain.callbacks.RandomSeedSetter(seeds=None, freq={'train:epoch': 2})¶
  Bases: deeptrain.callbacks.TraingenCallback
  Periodically sets random, numpy, TensorFlow graph-level, and TensorFlow global seeds. A usage sketch follows the method list below.
  - Arguments:
    - seeds: dict / None
      Dict of initial seeds; if None, will default all to 0. E.g.: {'random': 0, 'numpy': 0, 'tf-graph': 1, 'tf-global': 2}.
    - freq: dict
      When to set / reset the seeds. E.g. {'train:epoch': 2} (default) will set every 2 train epochs. By default, seeds are incremented by 1 with each setting, to avoid repeating random sequences.
  Methods can be overridden to increment by a different amount or not at all, and to clear the keras session & reset the TF default graph, which doesn't happen by default.
  - seeds¶
  - set_seeds(increment=0, seeds=None, reset_graph=False, verbose=1)¶
    Sets seeds. increment will add to each of self.seeds and update it. If seeds is not None, it will override increment.
  - classmethod _set_seeds(seeds=None, reset_graph=False, verbose=1)¶
    See set_seeds(); _set_seeds can be used without instantiating the class, as is done in deeptrain.set_seeds().
  - classmethod reset_graph(verbose=1)¶
    Clears keras session, and resets TensorFlow default graph.
    NOTE: after calling, best practice is to re-instantiate the model, else some internal operations may fail due to elements from different graphs interacting (from docs: "using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior").
  - call_on_freq()¶
  - on_train_iter_end(stage)¶
  - on_train_batch_end(stage)¶
  - on_train_epoch_end(stage)¶
  - on_val_iter_end(stage)¶
  - on_val_batch_end(stage)¶
  - on_val_epoch_end(stage)¶
  - on_val_end(stage)¶
  - on_save(stage)¶
  - on_load(stage)¶
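  A usage sketch (argument values illustrative): reset seeds every train epoch instead of every other one, starting from custom initial values:
  >>> from deeptrain.callbacks import RandomSeedSetter
  >>> seed_setter = RandomSeedSetter(
  ...     seeds={'random': 0, 'numpy': 0, 'tf-graph': 1, 'tf-global': 2},
  ...     freq={'train:epoch': 1})
  >>> # then pass to TrainGenerator via callbacks=[seed_setter]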
- class deeptrain.callbacks.VizAE2D(n_images=8, save_images=False)¶
  Bases: deeptrain.callbacks.TraingenCallback
  Image AutoEncoder reconstruction visualizer. Plots n_images original & reconstructed (model output) images in two rows, side-by-side vertically.
  - on_val_end(stage=None)¶
    Called by TrainGenerator._on_val_end(), with:
    - stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
    - stage='val_end' otherwise
  - viz()¶
  - get_data()¶
- class deeptrain.callbacks.TraingenLogger(savedir, configs, loadpath=None, get_data_fn=None, get_labels_fn=None, gather_fns=None, logname='datalog_', init_log_id=None)¶
  Bases: deeptrain.callbacks.TraingenCallback
  TensorBoard-like logger, gathering layer weights, outputs, and gradients over specified periods. A usage sketch follows the method list below.
  - Arguments:
    - savedir: str
      Path to directory where to save data and class instance state.
    - configs: dict
      Data mode-layer/weight name pairs for logging. Ex:
      >>> configs = {
      ...     'weights': ['dense_2', 'conv2d_1/kernel:0'],
      ...     'outputs': 'lstm_1',
      ...     'gradients': ('conv2d_1', 'dense_2:/bias:0'),
      ...     'gradients-kw': dict(mode='weights', learning_phase=0),
      ... }
      With mode='weights' for ('weights', 'gradients'), complete weight names may be specified - else, all layer weights will be included. If a layer name is a substring, the earliest match is returned. See see_rnn.inspect_gen methods.
    - loadpath: str / None
      Path to savefile of class instance state, to resume logging. If None, will attempt to fetch from savedir based on logname.
    - get_data_fn: function
      Will call to fetch input data to feed for 'gradients' and 'outputs' data. Defaults to lambda: TrainGenerator.val_datagen.get()[0].
    - get_labels_fn: function
      Will call to fetch labels to feed for 'gradients' data. Defaults to lambda: TrainGenerator.val_datagen.get()[1] if TrainGenerator.input_as_labels == False - else, to ~get()[0].
    - logname: str
      Base name of savefiles, prefixing save _id.
    - init_log_id: int / None
      Initial log _id; will increment by 1 with each call to log() (unless _id was passed to log()). Defaults to -1.
  - init_with_traingen(traingen)¶
    Instantiates configs, getter functions, & others; requires a TrainGenerator instance.
  - log(_id=None)¶
    Gathers data according to configs and gather_fns.
  - save(_id=None, clear=False, verbose=1)¶
    Saves data per _loggables, but not other class instance attributes.
    - Arguments:
      - _id: int / str[int] / None
        Appended to logname to make savepath in savedir.
      - clear: bool
        Whether to empty data attributes after saving.
      - verbose: bool / int[bool]
        Whether to print a save message with savepath.
  - load(verbose=1)¶
    Loads from an .h5 file according to loadpath, which can include data and other class instance attributes. If loadpath is None, will attempt to fetch from savedir based on logname. verbose == 1 / True to print a load message with loadpath.
  - clear(verbose=1)¶
    Sets attributes in _loggables to {}.
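  A usage sketch, assuming the named layers exist in the model; paths are illustrative:
  >>> from deeptrain.callbacks import TraingenLogger
  >>> configs = {'weights': ['dense_2'],     # log this layer's weights
  ...            'outputs': 'lstm_1',        # log this layer's outputs
  ...            'gradients': ('dense_2',)}  # log gradients w.r.t. weights
  >>> logger = TraingenLogger(savedir='logs/', configs=configs)
  >>> # after TrainGenerator.__init__ calls logger.init_with_traingen:
  >>> # logger.log()   # gather data per configs
  >>> # logger.save()  # write gathered data to 'logs/'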
deeptrain.data_generator module¶
- class deeptrain.data_generator.DataGenerator(data_path, batch_size=32, labels_path=None, preprocessor=None, preprocessor_configs=None, data_loader=None, labels_loader=None, preload_labels=None, shuffle=False, superbatch_path=None, set_nums=None, superbatch_set_nums=None, **kwargs)¶
  Bases: object
  Central interface between a directory and TrainGenerator. Handles data loading, preprocessing, shuffling, and batching. Requires only data_path to run.
  - Arguments:
    - data_path: str
      Path to directory to load data from.
    - batch_size: int
      Number of samples to feed the model at once. Can differ from the size of batches of loaded files; see "Flexible batch_size".
    - labels_path: str / None
      Path to labels file. If None, will not load labels; can be used with TrainGenerator.input_as_labels = True, feeding batch as labels in TrainGenerator.get_data() (e.g. autoencoders).
    - preprocessor: None / custom object / str in ('timeseries',)
      Transforms batch and labels right before both are returned by get(). See _set_preprocessor().
      - str: fetches one of the API-supported preprocessors.
      - None: uses GenericPreprocessor().
      - Custom object: must subclass Preprocessor().
    - preprocessor_configs: None / dict
      Kwargs to pass to preprocessor in case it's None, str, or an uninstantiated custom object. Ignored if preprocessor is instantiated.
    - data_loader: None / function / DataLoader()
      Object for loading data from directory / file.
      - function: passed as loader to DataLoader.__init__ in _infer_and_set_info(); input signature: (self, set_num)
      - DataLoader() instance: will set data_loader directly
      - Class subclassing DataLoader() (uninstantiated): will instantiate with attrs from _infer_info() & others
      - str: name of one of the loaders in util.data_loaders
      - None: defaults to one of those defined in util.data_loaders, as determined by _infer_info()
    - labels_loader: None / function / DataLoader()
      data_loader, but for labels.
    - preload_labels: bool / None
      Whether to load all labels into all_labels at __init__. Defaults to True if labels_path is a file or a directory containing a single file.
    - shuffle: bool
      If True, reset_state() will shuffle set_nums_to_process; the method is called by _on_epoch_end within TrainGenerator._train_postiter_processing() and TrainGenerator._val_postiter_processing() (via on_epoch_end()).
    - superbatch_path: str / None
      Path to file or directory from which to load superbatch (preload_superbatch()); see "Control Flow".
    - set_nums: list[int] / None
      Used to set set_nums_original and set_nums_to_process. If None, will infer from data_path; see "Control Flow".
    - superbatch_set_nums: list[int] / None
      set_nums to load into superbatch; see "Control Flow".
  How it works:
  Data is fed to TrainGenerator() via DataGenerator(). To work, data:
  - must be in one directory (or one file with all data)
  - file extensions must be the same (.npy, .h5, etc)
  - file names must be enumerated with a common base name (data1.npy, data2.npy, ...)
  - file batch size (# of samples, or dim 0 slices) should be the same, but can also be an integer or fractional multiple of batch_size (x2, x3, x1/2, x1/3, ...)
  - labels must be in one file - unless feeding input as labels (e.g. autoencoder), which doesn't require labels files; just pass TrainGenerator(input_as_labels=True)
  A minimal layout and construction sketch follows below.
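  A minimal sketch of a conforming layout and construction (paths, shapes, and file names are illustrative):
  >>> # data/train/batch__1.npy  # shape (32, 100)
  >>> # data/train/batch__2.npy  # shape (32, 100)
  >>> # data/train/labels.npy    # single labels file
  >>> from deeptrain.data_generator import DataGenerator
  >>> dg = DataGenerator(data_path='data/train/',
  ...                    labels_path='data/train/labels.npy',
  ...                    batch_size=32)
  >>> dg.advance_batch()  # load first batch per set_nums_to_process
  >>> x, y = dg.get()     # preprocessed (batch, labels)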
  Flexible batch_size:
  A loaded file's batch size may differ from batch_size, so long as the former is an integer multiple or integer fraction of the latter (sketched in numpy below). Ex:
  - len(loaded) == 32, batch_size == 64 -> will load another file and concatenate into len(batch) == 64.
  - len(loaded) == 64, batch_size == 32 -> will set the first half of loaded as batch and cache loaded, then repeat for the second half.
  'Technically', files need not be integer (/ fraction) multiples, as the following load order works with batch_size == 32: len(loaded) == 31, len(loaded) == 1 - but this is not recommended, as it's bound to fail if using shuffling, or if the total number of samples isn't divisible by batch_size. Other problems may also arise.
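  The size logic, sketched in plain numpy (array contents are placeholders):
  >>> import numpy as np
  >>> loaded, nxt = np.ones((32, 10)), np.ones((32, 10))
  >>> batch = np.concatenate([loaded, nxt])  # len 32 < batch_size 64:
  >>> batch.shape                            # files stacked to match
  (64, 10)
  >>> loaded = np.ones((64, 10))                # len 64 > batch_size 32:
  >>> first, second = loaded[:32], loaded[32:]  # one group served per
  ...                                           # batch, the rest cached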
  Control Flow:
  - set_num: index / Mapping key used to identify and get batch and labels via load_data() and load_labels(). Ex: for DataLoader.numpy_loader(), which expects files shaped (batches, samples, *), it'd do np.load(path)[set_num].
  - set_nums_to_process: will pop set_num from this list until it's empty; once empty, will set all_data_exhausted=True (update_state()).
  - set_nums_original: set_nums_to_process is reset to this with reset_state(). It's either the set_nums passed to __init__, or is inferred in _set_set_nums() as all available set_nums in the data_path file / directory.
  - superbatch: dict of set_num-batch pairs loaded persistently in memory (RAM), as opposed to batch, which is overwritten. Once loaded, batch can be drawn straight from superbatch if set_num is in it (i.e. in superbatch_set_nums).
  - get() returns batch and labels fed through Preprocessor.process().
  - advance_batch() gets the "next" batch and labels. "Next" is determined by set_num, which is popped from set_nums_to_process[0].
  - batch_exhausted: signals TrainGenerator() that a batch was consumed; this information is set via update_state() per _on_iter_end within TrainGenerator._train_postiter_processing() or TrainGenerator._val_postiter_processing().
    - If using slices (slices_per_batch is not None), then batch_exhausted is set to True only when slice_idx == slices_per_batch - 1.
  __init__:
  Instantiation. ("+" == if certain conditions are met)
  - +Infers missing configs based on args
  - Validates args & kwargs, and tries to correct, printing a "NOTE" or "WARNING" message where appropriate
  - +Preloads all labels into all_labels
  - Instantiates misc internal parameters to predefined values (may be overridden by TrainGenerator loading).
  - _BUILTINS = {'extensions': {'.csv', '.h5', '.npy'}, 'loaders': {'csv', 'hdf5', 'numpy', 'numpy-lz4f', 'numpy-memmap'}, 'preprocessors': (<class 'deeptrain.util.preprocessors.GenericPreprocessor'>, <class 'deeptrain.util.preprocessors.TimeseriesPreprocessor'>)}¶
  - get(skip_validation=False)¶
    Returns (batch, labels) fed to Preprocessor.process().
    - skip_validation: bool
      - False (default): calls _validate_batch(), which will advance_batch() if batch_exhausted, and reset_state() if all_data_exhausted.
      - True: fetch preprocessed (batch, labels) without advancing any internal states.
  - advance_batch(forced=False, is_recursive=False)¶
    Sets next batch and labels; handles dynamic batching.
    - If batch_loaded and not forced (and not is_recursive), prints a warning that batch is loaded, and returns (does nothing)
    - If len(batch) != batch_size:
      - < batch_size: calls advance_batch() with is_recursive = True. With each such call, batch and labels are extended (stacked) until matching batch_size.
      - > batch_size, not an integer multiple: raises Exception.
      - > batch_size, is an integer multiple: makes _group_batch and _group_labels, which are used to set batch and labels.
    - +If set_nums_to_process is empty, will raise Exception; it must have been reset beforehand via e.g. reset_state(). If it's not empty, sets set_num by popping from set_nums_to_process. (+: only if _group_batch is None)
    - Sets or extends batch via _get_next_batch() (by loading, or from _group_batch or superbatch).
    - +Sets or extends labels via _get_next_labels() (by loading, or from _group_labels, or all_labels). (+: only if labels_path is a path (and not None))
    - Sets set_name, used by TrainGenerator() to print iteration messages.
    - Sets batch_loaded = True, batch_exhausted = False, all_data_exhausted = False, and slice_idx to None if it's already None (else to 0).
  - _get_next_batch(set_num=None, warn=True)¶
    Gets batch per set_num.
    - set_num = None: will use self.set_num.
    - warn = False: won't print a warning on superbatch not being preloaded.
    - If _group_batch is not None, will get batch from _group_batch.
    - If set_num is in superbatch_set_nums, will get batch as superbatch[set_num] (if superbatch exists).
    - By default, gets batch via load_data().
  - _get_next_labels(set_num=None)¶
    Gets labels per set_num.
    - set_num = None: will use self.set_num.
    - If _group_labels is not None, will get labels from _group_labels.
    - If set_num is in superbatch_set_nums, will get labels as superbatch[set_num] (if superbatch exists).
    - By default, gets labels via load_labels() if labels_path is set - else, labels=[].
  - _batch_from_group_batch()¶
    Slices _group_batch per batch_size and _group_batch_idx.
  - _labels_from_group_labels()¶
    Slices _group_labels per batch_size and _group_batch_idx.
  - _update_group_batch_state()¶
    Sets "group" attributes to None once a sufficient number of batches was extracted, else increments _group_batch_idx.
  - on_epoch_end()¶
    Increments epoch, calls preprocessor.on_epoch_end(epoch), then reset_state(), and returns epoch.
  - update_state()¶
    Calls preprocessor.update_state(), and if batch_exhausted and set_nums_to_process == [], sets all_data_exhausted = True to signal TrainGenerator() of epoch end.
  - reset_state(shuffle=None)¶
    Calls preprocessor.reset_state(), sets batch_exhausted = True, batch_loaded = False, resets set_nums_to_process to set_nums_original, and shuffles set_nums_to_process if shuffle.
    - If the shuffle passed in is None, will set from self.shuffle.
    - Used in TrainGenerator.reset_validation() w/ shuffle=False.
  - _validate_batch()¶
    If all_data_exhausted, calls reset_state(). If batch_exhausted, calls advance_batch().
  - _make_group_batch_and_labels(n_batches)¶
    Makes _group_batch and _group_labels when loaded len(batch) exceeds batch_size as its integer multiple. May shuffle. _group_batch = np.asarray(batch), and _group_labels = np.asarray(labels); each's len() > batch_size.
    - Shuffles if:
      - shuffle_group_samples: shuffles all samples (dim0 slices)
      - shuffle_group_batches: groups dim0 slices by batch_size, then shuffles the groupings. Ex:
        >>> batch_size == 32
        >>> batch.shape == (128, 100)
        >>> batch = batch.reshape(4, 32, 100)  # (4, 32, 100) == .shape
        >>> shuffle(batch)                     # 24 (4!) permutations
        >>> batch = batch.reshape(128, 100)    # (128, 100) == .shape
    Sets _group_batch_idx = 0, and calls _update_group_batch_state().
    Doesn't affect labels if labels_path is falsy (e.g. None).
  - batch_exhausted¶
    Is retrieved from and set in preprocessor. Indicates that batch and labels for a given set_num were consumed by TrainGenerator (if using slices, that all slices were consumed).
    Ex: self.batch_exhausted = 5 will set self.preprocessor.batch_exhausted = 5, and print(self.batch_exhausted) will then print 5 (or something else if preprocessor changes it internally).
  - batch_loaded¶
    Is retrieved from and set in preprocessor, same as batch_exhausted. Indicates that batch and labels for a given set_num are loaded.
  - slices_per_batch¶
    Is retrieved from and set in preprocessor, same as batch_exhausted.
  - slice_idx¶
    Is retrieved from and set in preprocessor, same as batch_exhausted.
  - load_data¶
    Load and return batch data via data_loaders.DataLoader.load_fn(). Used by _get_next_batch() and preload_superbatch().
  - load_labels¶
    Load and return labels data via data_loaders.DataLoader.load_fn(). Used by _get_next_labels() and preload_all_labels().
  - _infer_info(path)¶
    Infers unspecified essential attributes from directory and contained files info:
    - Checks that the data directory (path) isn't empty (files whose names start with '.' aren't counted)
    - Retrieves data filepaths per path and gets the data extension (the most frequent ext in dir, excluding the "other path" from the count if in the same dir; the "other path" is data_path if path == labels_path, and vice versa)
    - Gets base_name as the longest common substring among files with the ext extension
    - If path is a path to a file, then filepaths=[path].
    - If path is None, returns base_name=None, ext=None, filepaths=[].
  - _infer_and_set_info(data_loader, labels_loader)¶
    Sets data_loader and labels_loader (DataLoader()), using info obtained from _infer_info().
    - If info contains only one filepath, the loader will operate with _is_dataset=True.
    - If preload_labels is None and labels_loader._is_dataset, will set preload_labels=True.
    data_loader / labels_loader are:
    - function: passed as loader to DataLoader.__init__; input signature: (self, set_num)
    - DataLoader() instance: will set data_loader directly
    - Class subclassing DataLoader() (uninstantiated): will instantiate with attrs from _infer_info() & others
    - str: name of one of the loaders in util.data_loaders
    - None: defaults to one of those defined in util.data_loaders, as determined by _infer_info()
  - _set_set_nums(set_nums, superbatch_set_nums)¶
    Sets set_nums_original, set_nums_to_process, and superbatch_set_nums.
    - Fetches set_nums via DataLoader._get_set_nums()
    - Sets set_nums_to_process and set_nums_original; if set_nums weren't passed to __init__, sets to fetched ones.
    - If set_nums were passed, validates that they're a subset of the fetched ones (i.e. can be seen by data_loader).
    - Sets superbatch_set_nums; if passed to __init__ as 'all', sets to the fetched ones. If passed as a list, validates that they're a subset of the fetched ones.
    - Does not validate set_nums from labels_loader's perspective; the user is expected to supply labels for each batch with a common set_num.
  - _set_preprocessor(preprocessor, preprocessor_configs)¶
    Sets preprocessor, based on the preprocessor passed to __init__:
    - If None, sets to GenericPreprocessor(), instantiated with preprocessor_configs.
    - If an uninstantiated class, will validate that it subclasses Preprocessor(), then instantiate with preprocessor_configs.
    - If string, will match to a supported builtin.
    - Validates that the set preprocessor subclasses Preprocessor().
  - preload_superbatch()¶
    Loads all data specified by superbatch_set_nums via load_data(), and assigns them to superbatch for each set_num.
  - preload_all_labels()¶
    Loads all labels into all_labels using load_labels(), based on set_nums_original.
  - _init_and_validate_kwargs(kwargs)¶
    Sets and validates kwargs passed to __init__.
    - Ensures data_path is a file or a directory, and labels_path is a file, directory, or None.
    - Ensures kwargs are functional (compares against names in _DEFAULT_DATAGEN_CFG).
    - Sets whichever names were passed with kwargs, and defaults the rest.
  - _init_class_vars()¶
    Instantiates various internal attributes. Most of these are saved and loaded by TrainGenerator() by default.
deeptrain.introspection module¶
- deeptrain.introspection.compute_gradient_norm(self, input_data, labels, sample_weight=None, learning_phase=0, _id='*', mode='weights', norm_fn=(<ufunc 'sqrt'>, <ufunc 'square'>), scope='local')¶
  Computes gradients w.r.t. layer weights or outputs per _id, and returns the norm according to norm_fn and scope.
  - Arguments:
    - input_data: np.ndarray / list[np.ndarray] / supported formats
      Data w.r.t. which loss is to be computed for the gradient. List of arrays for multi-input networks. "Supported formats" is any valid input to model.
    - labels: np.ndarray / list[np.ndarray] / supported formats
      Labels w.r.t. which loss is to be computed for the gradient.
    - sample_weight: np.ndarray / list[np.ndarray] / supported formats
      kwarg to model.fit(), etc., weighting individual sample losses.
    - learning_phase: bool / int[bool]
      - 1: use model in train mode
      - 0: use model in inference mode
    - _id: str / int / list[str/int]
      - int -> idx; str -> name
      - idx: int. Index of layer to fetch, via model.layers[idx].
      - name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
      - list[str/int] -> treat each str element as name, int as idx. Ex: ['gru', 2] gets (e.g.) weights of the first layer with name substring 'gru', then of the layer w/ idx 2.
      - '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with an 'output' attribute.
    - mode: str in ('weights', 'outputs', 'gradients:weights', 'gradients:outputs')
      Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
    - norm_fn: (function, function) / function
      Norm function(s) to apply to gradient arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(x) for x in data)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise (sketched in numpy below).
    - scope: str in ('local', 'global')
      Whether to apply stat_fn on individual gradient arrays, or on their sum.
  - Returns:
    - Gradient norm(s). List of floats if scope == 'local' (norms of weights), else float (outer_fn(sum(sum(inner_fn(g)) for g in grads))).
  TensorFlow optimizers do gradient clipping according to the clipnorm setting by comparing individual weights' L2-norms against clipnorm, and rescaling if exceeded. These L2 norms can be obtained using norm_fn=(np.sqrt, np.square) with scope == 'local' and mode='weights'. See:
  - tensorflow.python.keras.optimizer_v2.optimizer_v2._clip_gradients
  - keras.optimizers.clip_norm
  - tensorflow.python.ops.clip_ops.clip_by_norm
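  The norm arithmetic per the definitions above, sketched in numpy (gradient arrays are placeholders):
  >>> import numpy as np
  >>> grads = [np.random.randn(3, 4), np.random.randn(4)]  # placeholders
  >>> outer_fn, inner_fn = np.sqrt, np.square              # L2-norm pair
  >>> # scope='local': one norm per gradient array
  >>> local_norms = [outer_fn(np.sum(inner_fn(g))) for g in grads]
  >>> # scope='global': a single norm over all arrays
  >>> global_norm = outer_fn(sum(np.sum(inner_fn(g)) for g in grads))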
- deeptrain.introspection.gradient_norm_over_dataset(self, val=False, learning_phase=0, mode='weights', norm_fn=(<ufunc 'sqrt'>, <ufunc 'square'>), stat_fn=<function median>, n_iters=None, prog_freq=10, w=1, h=1)¶
  Aggregates gradient norms over the dataset, one iteration at a time. Useful for estimating a value of gradient clipping, clipnorm, to use. Plots a histogram of the gathered data when finished. Also see compute_gradient_norm().
  - Arguments:
    - val: bool
      - True: gather over val_datagen batches
      - False: gather over datagen batches
    - learning_phase: bool / int[bool]
      - True: get gradients of model in train mode
      - False: get gradients of model in inference mode
    - mode: str in ('weights', 'outputs')
      Whether to get gradients with respect to layer weights or outputs.
    - norm_fn: (function, function) / function
      Norm function(s) to apply to gradient arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(g) for g in grads)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise.
    - stat_fn: function
      Aggregate function to apply on the computed norms. If np.mean, will gather the mean of gradients; if np.median, the median, etc. Computed as: stat_fn(outer_fn(sum(inner_fn(g) for g in grads))).
    - n_iters: int / None
      Number of expected iterations over the entire dataset. Can be used to iterate over a subset of the entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
    - prog_freq: int
      How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but they are the same if not using slices). E.g. 5: ....|5....|10....|15.
    - w, h: float
      Scale figure width & height, respectively.
  - Returns:
    - grad_norms: np.ndarray
      Norms of gradients for every iteration. Shape: (iters_processed, n_params), where n_params is the number of gradient arrays whose norm stats were computed at each iteration.
    - batches_processed: int
      Number of batches processed.
    - iters_processed: int
      Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
- deeptrain.introspection.gradient_sum_over_dataset(self, val=False, learning_phase=0, mode='weights', n_iters=None, prog_freq=10, plot_kw={})¶
  Computes the cumulative sum of gradients over the dataset, one iteration at a time, preserving full array shapes. Useful for computing the mean of gradients over the dataset, or other aggregate metrics.
  - Arguments:
    - val: bool
      - True: gather over val_datagen batches
      - False: gather over datagen batches
    - learning_phase: bool / int[bool]
      - True: get gradients of model in train mode
      - False: get gradients of model in inference mode
    - mode: str in ('weights', 'outputs')
      Whether to get gradients with respect to layer weights or outputs.
    - n_iters: int / None
      Number of expected iterations over the entire dataset. Can be used to iterate over a subset of the entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
    - prog_freq: int
      How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but they are the same if not using slices). E.g. 5: ....|5....|10....|15.
    - plot_kw: dict
      Kwargs to pass to see_rnn.features_hist; defaults to {'share_xy': False, 'center_zero': True}.
  - Returns:
    - grad_sum: dict[str: np.ndarray]
      Gradient arrays summed over the dataset. Structure: {name: array, name: array, ...}, where name is the name of a weight array or layer output.
    - batches_processed: int
      Number of batches processed.
    - iters_processed: int
      Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
- deeptrain.introspection._gather_over_dataset(self, gather_fn, val=False, n_iters=None, prog_freq=10)¶
  Iterates over DataGenerator, applying gather_fn to every batch (or slice). Stops after n_iters, or when DataGenerator.all_data_exhausted if n_iters is None. Useful for monitoring quantities over the course of training or inference. gather_fn recursively updates data; as such, it can be used to append to a list, update a dictionary, operate on an array, etc. Review source code for exact logic.
- deeptrain.introspection._make_gradients_fn(model, learning_phase, mode, return_names=False)¶
  Makes a reusable gradient-getter function, separately for TF Eager & Graph execution. The Eager variant is only pseudo-reusable; gradient tensors are still fetched all over again - Graph should be significantly faster.
- deeptrain.introspection.print_dead_weights(model, dead_threshold=1e-07, notify_above_frac=0.001, notify_detected_only=False)¶
  Prints names of dead weights and their proportions. Useful for debugging vanishing and exploding gradients, or quantifying sparsity.
  - Arguments:
    - model: models.Model / models.Sequential (keras / tf.keras)
      The model.
    - dead_threshold: float
      Threshold below which to count the weight as "dead", in absolute value.
    - notify_above_frac: float
      Print only if the fraction of weights counted "dead" exceeds this (e.g. if there are 11 absolute values < dead_threshold out of 1000).
    - notify_detected_only: bool
      - True: print text only if dead weights are discovered
      - False: print a "not found given thresholds" message when appropriate
- deeptrain.introspection.print_nan_weights(model, notify_detected_only=False)¶
  Prints names of NaN/Inf weights and their proportions. Useful for debugging exploding or buggy gradients.
  - Arguments:
    - model: models.Model / models.Sequential (keras / tf.keras)
      The model.
    - notify_detected_only: bool
      - True: print text only if NaN/Inf weights are discovered
      - False: print a "none found" message if no NaNs were found
- deeptrain.introspection.print_large_weights(model, large_threshold=3, notify_above_frac=0.001, notify_detected_only=False)¶
  Prints names of weights in excess of a set absolute value, and their proportions; excludes Inf. Useful for debugging exploding or buggy gradients.
  - Arguments:
    - model: models.Model / models.Sequential (keras / tf.keras)
      The model.
    - large_threshold: float
      Threshold above which to count the weight's absolute value as "large".
    - notify_above_frac: float
      Print only if the fraction of weights counted "large" exceeds this (e.g. if there are 11 absolute values > large_threshold out of 1000).
    - notify_detected_only: bool
      - True: print text only if large weights are discovered
      - False: print a "none found" message otherwise
- deeptrain.introspection._sample_weight_built(model)¶
  In Graph execution, model._feed_sample_weights isn't built unless the model is compiled with sample_weight_mode set, or train_on_batch or test_on_batch is called w/ sample_weight passed.
- deeptrain.introspection.interrupt_status(self) -> (bool, bool)¶
  Prints whether TrainGenerator was interrupted (e.g. KeyboardInterrupt, or via exception) during train() and validate(). Returns bools (True for interrupted, else False) for each, as (train, val).
  Not foolproof; the user can set flags manually or via callbacks. For further assurance, check temp_history, val_temp_history, and cache attributes (e.g. _preds_cache), which are cleared at the end of validate() by default; this method checks only flags: _train_loop_done, _train_postiter_processed, _val_loop_done, _val_postiter_processed.
- deeptrain.introspection.info(self)¶
  Prints various useful TrainGenerator & DataGenerator attributes, and interrupt status.
deeptrain.metrics module¶
- deeptrain.metrics.f1_score(y_true, y_pred, pred_threshold=0.5, beta=1)¶
- deeptrain.metrics.mean_squared_error(y_true, y_pred, sample_weight=1)¶
- deeptrain.metrics.mean_absolute_error(y_true, y_pred, sample_weight=1)¶
- deeptrain.metrics.roc_auc_score(y_true, y_pred)¶
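These follow the standard definitions; e.g., a minimal numpy sketch of the F1 computation with beta=1 (not the library's exact implementation):
>>> import numpy as np
>>> def f1_sketch(y_true, y_pred, pred_threshold=0.5):
...     y_pred = np.asarray(y_pred) >= pred_threshold  # binarize preds
...     y_true = np.asarray(y_true).astype(bool)
...     tp = np.sum(y_pred & y_true)                   # true positives
...     precision = tp / max(np.sum(y_pred), 1)
...     recall    = tp / max(np.sum(y_true), 1)
...     return 2 * precision * recall / max(precision + recall, 1e-12)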
deeptrain.preprocessing module¶
- deeptrain.preprocessing.data_to_hdf5(savepath, batch_size, loaddir=None, data=None, shuffle=False, compression='lzf', dtype=None, load_fn=None, oversample_remainder=True, batches_dim0=False, overwrite=None, verbose=1)¶
  Converts data to hdf5-group (.h5) format, in batch_size sample sets; a usage sketch follows the Notes below.
  - Arguments:
    - savepath: str
      Absolute path to where to save the file.
    - batch_size: int
      Number of samples (dim0 slices) to save per set.
    - loaddir: str
      Absolute path to directory from which to load data.
    - data: np.ndarray / list[np.ndarray]
      Shape: (samples, *) or (batches, samples, *) (the latter requires batches_dim0=True). With the former, if len(data) == 320 and batch_size == 32, will make a 10-set .h5 file.
    - shuffle: bool
      Whether to shuffle samples (dim0 slices).
    - compression: str
      Compression type to use; kwarg to h5py.File().create_dataset().
    - dtype: str / np.dtype
      Savefile dtype; kwarg to .create_dataset(). Defaults to data's dtype.
    - load_fn: function / callable
      Used on supported paths (.npy) in loaddir to load data.
    - oversample_remainder: bool. Relevant only when passing data.
      - True -> randomly draw (batch_size - remainder) samples to fill the incomplete batch.
      - False -> drop remainder samples.
    - batches_dim0: bool
      Assumed shapes - True: (batches, samples, *); False: (samples, *).
    - overwrite: bool / None
      If savepath file exists,
      - True -> replace it
      - False -> don't replace it
      - None -> ask confirmation via user input
    - verbose: bool
      Whether to print preprocessing progress.
  Notes:
  - If supplying loaddir instead of data, will iteratively load files with a supported format (.npy). len() of a loaded file must be batch_size times an integer fraction <= 1 (i.e., must evenly divide batch_size). So batch_size == 32 and len() == 16 works, but len() == 48 or len() == 24 doesn't.
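  A usage sketch (shapes and path illustrative): converting a (320, 100) array into a 10-set .h5 file:
  >>> import numpy as np
  >>> from deeptrain.preprocessing import data_to_hdf5
  >>> data = np.random.randn(320, 100)        # (samples, *)
  >>> data_to_hdf5('data.h5', batch_size=32,  # -> 10 sets of 32 samples
  ...              data=data, overwrite=True)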
- deeptrain.preprocessing.numpy_to_lz4f(data, savepath=None, level=9, overwrite=None)¶
  Does lz4-framed compression on data. (Install the compressor via !pip install py-lz4framed)
  - Arguments:
    - data: np.ndarray
      Data to compress.
    - savepath: str
      Path to where to save the file.
    - level: int
      1 to 9; higher = greater compression
    - overwrite: bool
      If savepath file exists,
      - True -> replace it
      - False -> don't replace it
      - None -> ask confirmation via user input
  - Returns:
    - np.ndarray - compressed array.
  Example:
  >>> numpy_to_lz4f(savedata, savepath=path)
  ...
  >>> # load & decompress
  >>> bytes_npy = lz4f.decompress(np.load(path))
  >>> loaddata = np.frombuffer(bytes_npy,
  ...                          dtype=savedata.dtype,  # must be original's
  ...                          ).reshape(*savedata.shape)
- deeptrain.preprocessing.numpy_data_to_numpy_sets(data, labels, savedir=None, batch_size=32, shuffle=True, data_basename='batch', oversample_remainder=True, overwrite=None, verbose=1)¶
  Saves data in batch_size chunks, possibly shuffling samples.
  - Arguments:
    - data: np.ndarray
      Data to batch along labels & save.
    - labels: np.ndarray
      Labels to batch along data and save.
    - savedir: str / None
      Directory in which to save processed data. If None, won't save.
    - batch_size: int
      Number of samples (dim0 slices) to form a 'set' with.
    - data_basename: str
      Will save with this prepending the set numbering - e.g.: 'batch__1.npy', 'batch__2.npy', ...
    - oversample_remainder: bool
      - True -> randomly draw (batch_size - remainder) samples to fill the incomplete batch.
      - False -> drop remainder samples.
    - overwrite: bool / None
      If savepath file exists,
      - True -> replace it
      - False -> don't replace it
      - None -> ask confirmation via user input
    - verbose: bool
      Whether to print preprocessing progress.
  - Returns:
    - data, labels: processed data & labels.
- deeptrain.preprocessing.numpy2D_to_csv(data, savepath=None, batch_size=None, columns=None, sample_dim=1, overwrite=None)¶
  Saves 2D data as .csv.
  - Arguments:
    - data: np.ndarray
      Data to save, shaped (batches, samples).
    - savepath: str
      Path to where to save the file.
    - batch_size: int
      Number of rows per column; can differ from data's.
    - columns: list of str
      Column names for the data frame; defaults to enumerated columns (0, 1, 2, ...)
    - sample_dim: int
      Dimension applicable to batch_size.
    - overwrite: bool
      If savepath file exists,
      - True -> replace it
      - False -> don't replace it
      - None -> ask confirmation via user input
  Example:
  Suppose we have labels, 16 per batch, and 8 batches. Each batch is shaped (16,) - stacked, (8, 16). Dim 1 is thus sample_dim. Also see examples/preprocessing/timeseries.py.
  >>> data.shape == (8, 16)  # (num_batches, samples)
  >>> numpy2D_to_csv(data, "data.csv", batch_size=32, sample_dim=1)
  >>> # if it was (samples, num_batches), sample_dim would be 0.
  ... # This will make a DataFrame of 4 columns, 32 rows per column,
  ... # reshaping `data` to correctly concatenate samples from dim0 to dim1.
deeptrain.train_generator module¶
- class deeptrain.train_generator.TrainGenerator(model, datagen, val_datagen, epochs=1, logs_dir=None, best_models_dir=None, loadpath=None, callbacks=None, fit_fn='train_on_batch', eval_fn='evaluate', key_metric='loss', key_metric_fn=None, val_metrics=None, custom_metrics=None, input_as_labels=False, max_is_best=None, val_freq={'epoch': 1}, plot_history_freq={'epoch': 1}, unique_checkpoint_freq={'epoch': 1}, temp_checkpoint_freq=None, class_weights=None, val_class_weights=None, reset_statefuls=False, iter_verbosity=1, logdir=None, optimizer_save_configs=None, optimizer_load_configs=None, plot_configs=None, model_configs=None, **kwargs)¶
  Bases: deeptrain.util._traingen_utils.TraingenUtils
  The central DeepTrain class. Interfaces training, validation, checkpointing, data loading, and progress tracking.
  - Arguments:
    - model: models.Model / models.Sequential [keras / tf.keras]
      Compiled model to train.
    - datagen: DataGenerator
      Train data generator; fetches inputs and labels, handles preprocessing, shuffling, stateful formats, and informing TrainGenerator when a dataset is exhausted (epoch end).
    - val_datagen: DataGenerator
      Validation data generator.
    - epochs: int
      Number of train epochs.
    - logs_dir: str / None
      Path to directory where to generate log directories, which include TrainGenerator state, state report, model data, and others; see checkpoint(). If None, will not checkpoint - but model saving is still possible via best_models_dir.
    - best_models_dir: str / None
      Path to directory where to save the best model. "Best" means having a new highest (max_is_best==True) or lowest (max_is_best==False) entry in key_metric_history. See _save_best_model().
    - loadpath: str / None
      Path to .h5 file containing TrainGenerator state to load (postfixed '__state.h5' by default). See load().
    - callbacks: dict[str: function] / TraingenCallback / None
      Functions to apply at various stages, including training, validation, saving, loading, and __init__. See TraingenCallback.
    - fit_fn: str / function(x, y, sample_weight)
      Function, or name of model method, to feed data to during training; if str, will define fit_fn = getattr(model, 'fit') (example). If function, its name (substring) must include 'fit' or 'train' (currently both function identically).
    - eval_fn: str / function(x, y, sample_weight)
      Function, or name of model method, to feed data to during validation; if str, will define eval_fn = getattr(model, 'evaluate') (example). If function, its name (substring) must include 'evaluate' or 'predict':
      - 'evaluate': eval_fn uses data & labels to return metrics.
      - 'predict': eval_fn uses data to return predictions, which are used internally to compute metrics.
    - key_metric: str
      Name of metric to track for saving the best model; will store in key_metric_history. See _save_best_model().
    - key_metric_fn: function / None
      Custom function to compute the key metric; overrides key_metric if not None.
    - val_metrics: list[str] / None
      Names of metrics to track during validation.
      - If 'predict' is not in eval_fn.__name__, is overridden by model metrics (model.compile(metrics=...))
      - If 'loss' is not included, will prepend it.
      - If '*' is included, will insert model metrics at its position and pop '*'. Ex: ['*', 'f1_score'] -> ['loss', 'accuracy', 'f1_score'].
    - custom_metrics: dict[str: function]
      Name-function pairs of custom functions to use for gathering metrics. Functions must obey the (y_true, y_pred) input signature for the first two arguments. They may additionally take sample_weight and pred_threshold, which will be detected and used automatically.
      - Note: if using a custom metric in model.compile(loss=tf_fn), the name in custom_metrics must be the function's code name, i.e. {tf_fn.__name__: fn} (where fn is e.g. the numpy version).
    - input_as_labels: bool
      Feed model input also to its output. Ex: autoencoders.
    - max_is_best: bool
      Whether to consider a greater key_metric as better in saving the best model. See _save_best_model(). If None, defaults to False if key_metric=='loss', else True.
    - val_freq: None / dict[str: int], str in {'iter', 'batch', 'epoch', 'val'}
      How frequently to validate. {'epoch': 1} -> every epoch; {'iter': 24} -> every 24 train iterations. Only one key-value pair supported. If None, won't validate.
    - plot_history_freq: None / dict[str: int], str in {'iter', 'batch', 'epoch', 'val'}
      How frequently to plot train & validation history. Only one key-value pair supported. If None, won't plot history.
    - unique_checkpoint_freq: None / dict[str: int], str in {'iter', 'batch', 'epoch', 'val'}
      How frequently to make checkpoints with unique savefile names, as opposed to temporary ones which are overwritten each time. Only one key-value pair supported. If None, won't make unique checkpoints.
    - temp_checkpoint_freq: None / dict[str: int], str in {'iter', 'batch', 'epoch', 'val'}
      How frequently to make checkpoints with the same predefined name, to be overwritten ("temporary"); serves as an intermediate checkpoint to unique ones, if needed. Only one key-value pair supported. If None, won't make temporary checkpoints.
    - class_weights: dict[int: int] / None
      Integer mapping of class labels to their "weights"; if not None, will feed sample_weight mediated by the weights to the train function (fit_fn).
      >>> class_weights = {0: 4, 1: 1}
      >>> labels == [1, 1, 0, 1]         # if
      >>> sample_weight == [1, 1, 4, 1]  # then
    - val_class_weights: dict[int: int] / None
      class_weights, but for the validation function (eval_fn).
    - reset_statefuls: bool
      Whether to call model.reset_states() at the end of every batch (train and val).
    - iter_verbosity: int in {0, 1, 2}
      - 0: print no iteration info
      - 1: print the name of the set being fit / validated, metric names and values, and model.reset_states() being called
      - 2: print a '.' at every iteration (useful if having multiple iterations per batch)
    - logdir: str / None
      Directory where to write logs to (see logs_dir). Use to specify an existing directory (to, for example, resume training and logging in the original folder). Overrides logs_dir.
    - optimizer_save_configs: dict / None
      Dict specifying which optimizer attributes to include or exclude when saving. See save().
    - optimizer_load_configs: dict / None
      Dict specifying which optimizer attributes to include or exclude when loading. See load().
    - plot_configs: dict / None
      Dict specifying get_history_fig() behavior. See _DEFAULT_PLOT_CFG, and _make_plot_configs_from_metrics().
    - model_configs: dict / None
      Dict specifying model information. Intended usage is: create model according to the dict, specifying hyperparameters, loss function, etc.
      >>> def make_model(batch_shape, units, optimizer, loss):
      ...     ipt = Input(batch_shape=batch_shape)
      ...     out = Dense(units)(ipt)
      ...     model = Model(ipt, out)
      ...     model.compile(optimizer, loss)
      ...     return model
      ...
      >>> model_configs = {'batch_shape': (32, 16), 'units': 8,
      ...                  'optimizer': 'adam', 'loss': 'mse'}
      >>> model = make_model(**model_configs)
      Checkpoints will include an image report with the entire dict; the larger the portion of the model that's created according to model_configs, the more will be documented for easy reference.
    - kwargs: keyword arguments
      See _DEFAULT_TRAINGEN_CFG. kwargs and all other arguments are subject to validation and correction by _validate_traingen_configs().
  __init__:
  Instantiation. ("+" == if certain conditions are met)
  - Pulls methods from TraingenUtils
  - Validates args & kwargs, and tries to correct, printing a "NOTE" or "WARNING" message where appropriate
  - +Instantiates logging directory
  - +Loads TrainGenerator, datagen, and val_datagen states
  - +Loads model and optimizer weights (but not model architecture)
  - +Preloads train & validation data (before a call to train() is made)
  - +Applies initial callbacks
  - +Logs initial state (_log_init_state())
  - Captures and saves all arguments passed to __init__
  - Instantiates misc internal parameters to predefined values (may be overridden by loading).
  A minimal usage sketch follows below.
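  A minimal end-to-end sketch, assuming a compiled tf.keras model and conforming data directories (see DataGenerator above); paths are illustrative:
  >>> from deeptrain.train_generator import TrainGenerator
  >>> from deeptrain.data_generator import DataGenerator
  >>> datagen     = DataGenerator(data_path='data/train/', batch_size=32,
  ...                             labels_path='data/train/labels.npy')
  >>> val_datagen = DataGenerator(data_path='data/val/', batch_size=32,
  ...                             labels_path='data/val/labels.npy')
  >>> tg = TrainGenerator(model, datagen, val_datagen,  # model: compiled
  ...                     epochs=8, logs_dir='logs/')
  >>> tg.train()  # runs the train loop, validating per val_freq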
  - train()¶
    The train loop.
    - Fetches data from get_data
    - Fits data via fit_fn
    - Processes fit metrics in _train_postiter_processing
    - Stores metrics in history
    - Applies 'train:iter', 'train:batch', and 'train:epoch' callbacks
    - Calls validate when appropriate
    Interruption:
    - Safe: during get_data, which can be called indefinitely without changing any attributes.
    - Avoid: during _train_postiter_processing, where fit_fn has already been applied and weights updated, but metrics aren't yet stored; with _train_postiter_processed=False, the loop restarts without recording progress.
    - Best bet is during validate(), as get_data may be too brief.
  - validate(record_progress=True, clear_cache=True, restart=False, use_callbacks=True)¶
    Validation loop.
    - Fetches data from get_data
    - Applies function based on _eval_fn_name
    - Processes and caches metrics/predictions in _val_postiter_processing
    - Applies 'val:iter', 'val:batch', and 'val:epoch' callbacks
    - Calls _on_val_end at end of validation to compute metrics and store them in val_history
    - Applies 'val_end' and maybe ('val_end', 'train:epoch') callbacks
    - If restart, calls reset_validation().
    - Arguments:
      - record_progress: bool
        If False, won't update val_history, _val_iters, _batches_validated.
      - clear_cache: bool
        If False, won't call clear_cache(); useful for keeping preds & labels acquired during validation.
      - restart: bool
        If True, will call reset_validation() before the validation loop to reset validation attributes; useful for starting afresh (e.g. if interrupted).
      - use_callbacks: bool
        If False, won't call apply_callbacks() or plot_history().
    Interruption:
    - Safe: during get_data, which can be called indefinitely without changing any attributes.
    - Avoid: during _val_postiter_processing. The model remains unaffected*, but caches are updated; a restart may yield duplicate appending, which will error or yield inaccuracies. (*a forward pass may consume random seed if random ops are used)
    - In practice: prefer interrupting immediately after _print_iter_progress executes.
  - _train_postiter_processing(metrics)¶
    Procedures done after every train iteration. Similar to _val_postiter_processing(), except operating on train rather than val variables, and calling validate() when appropriate.
  - _val_postiter_processing(record_progress=True, use_callbacks=True, metrics=None, batch_size=None)¶
    Procedures done after every validation iteration. Unless marked "always", they are conditional and may be skipped.
    - Update temp val history (always)
    - Update val_datagen state (always)
    - Update val cache (preds, labels, etc)
    - Update val history
    - Reset statefuls
    - Print progress
    - Apply callbacks
    Executes internal "callbacks" when appropriate: _on_iter_end, _on_batch_end, _on_epoch_end. List not exhaustive.
  - _on_val_end(record_progress, use_callbacks, clear_cache)¶
    Procedures done after validate(). Unless marked "always", they are conditional and may be skipped. List not exhaustive.
    - Update train/val history
    - Clear cache
    - Plot history
    - Checkpoint
    - Apply callbacks
    - Check model health
    - Validate batch_size (always)
    - Reset validation flags: _inferred_batch_size, _val_loop_done, _train_loop_done (always)
  - get_data(val=False)¶
    Gets train (val=False) or validation (val=True) data from datagen or val_datagen, respectively. See DataGenerator.
    DataGenerator.get() returns x, labels; if input_as_labels == True, sets y = x - else, y = labels. Either way, sets class_labels = labels. Generates sample_weight from class_labels.
  - clear_cache(reset_val_flags=False)¶
    Call to reset cache attributes accumulated during validation; useful for "restarting" validation (before calling validate()).
    Attributes set to []: {'_preds_cache', '_labels_cache', '_sw_cache', '_class_labels_cache', '_set_name_cache', '_val_set_name_cache', '_y_true', '_val_sw'}.
  - reset_validation()¶
    Used to restart validation (e.g. in case interrupted); calls clear_cache() and DataGenerator.reset_state() (and, if reset_statefuls, model.reset_states()).
    Does not reset validation counters (e.g. _val_iters).
  - _should_do(freq_config, forced=False)¶
    Checks whether a counter meets a frequency as specified in val_freq, plot_history_freq, unique_checkpoint_freq, temp_checkpoint_freq.
    "Counter" is one of _fit_iters, _batches_fit, epoch, and _times_validated. Ex: with unique_checkpoint_freq = {'batch': 5}, checkpoint() will make a unique checkpoint on every 5th batch fitted during train(). A sketch of the idea follows below.
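  A sketch of the idea (hypothetical helper, not the exact internal code):
  >>> def should_do_sketch(freq_config, counters):
  ...     """freq_config: e.g. {'batch': 5}; counters: counts per unit."""
  ...     if freq_config is None:
  ...         return False
  ...     unit, freq = list(freq_config.items())[0]  # one key-value pair
  ...     return counters[unit] % freq == 0
  ...
  >>> should_do_sketch({'batch': 5}, {'batch': 10})  # every 5th batch
  True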
  - _update_val_iter_cache()¶
    Called by _on_iter_end within _val_postiter_processing(); updates validation cache variables (_labels_cache, _preds_cache, _class_labels_cache, _sw_cache).
    If val_datagen has a non-None slice_idx, will preserve batch-slice structure:
    >>> [[y00, y01, y02], [y10, y11, y12]]    # 2 batches, 3 slices/batch
    >>> [[y00, y01], [y10, y11], [y20, y21]]  # 3 batches, 2 slices/batch
  - _print_train_progress()¶
    Called within _train_postiter_processing(), by on_batch_end().
  - _print_val_progress()¶
    Called within _val_postiter_processing(), by on_batch_end().
  - _print_progress(metrics, endchar='\n')¶
    Called by _print_train_progress() and _print_val_progress().
  - _print_iter_progress(val=False)¶
    Called within train() and validate().
  - plot_history(update_fig=True, w=1, h=1)¶
    Plots train & validation history (from history and val_history).
    - update_fig=True -> store the latest fig in _history_fig.
    - w & h scale the figure's width & height, respectively.
    - Plots are configured by plot_configs.
  - _apply_callbacks(stage)¶
    Callbacks. See examples/callbacks. Two approaches:
    - Class-based: inherit deeptrain.callbacks.TraingenCallback, define stage-based methods, e.g. on_train_epoch_end. Methods also take a stage argument for further control, e.g. to only call on_train_epoch_end when stage == ('val_end', 'train:epoch').
    - Function-based: make a dict of stage-function call pairs, e.g.:
      >>> {'train:epoch': (fn1, fn2),
      ...  'val:batch': fn3,
      ...  ('val_end', 'train:epoch'): fn4}
    A callback will execute if its key is in the stage passed to _apply_callbacks; e.g. (fn1, fn2) will execute on stage==('val_end', 'train:epoch'), since their key 'train:epoch' is in it - but fn4 won't execute on stage=='train:epoch', since its tuple key is not.
  - _init_callbacks()¶
    Instantiates callback objects (must subclass TraingenCallback), passing in the TrainGenerator instance as first (and only) argument. Enables custom callbacks utilizing TrainGenerator attributes and methods.
  - check_health(dead_threshold=1e-07, dead_notify_above_frac=0.001, large_threshold=3, large_notify_above_frac=0.001, notify_detected_only=True)¶
    Checks whether any layer weights have 'zeros' or NaN weights; very fast / inexpensive.
    - Arguments:
      - dead_threshold: float
        Count values below this as zeros.
      - dead_notify_above_frac: float
        If the fraction of such values exceeds this, print it and the weight's name.
      - notify_detected_only: bool
        - True -> print only if dead/NaN found
        - False -> print a 'not found' message otherwise
  - destroy(confirm=False, verbose=1)¶
    Class 'destructor'. Sets own, datagen's, and val_datagen's attributes to [] (which can free memory of arrays), then deletes them. Also deletes the 'model' attribute, but this has no effect on memory allocation until it's dereferenced globally and the TensorFlow/Keras graph is cleared (best bet is to restart the Python kernel).
  - fit_fn¶
  - eval_fn¶
  - epoch¶
  - val_epoch¶
  - _prepare_initial_data(from_load=False)¶
    Preloads the first batch for training and validation, and superbatch if available.
  - _init_logger()¶
    Instantiates the log directory for checkpointing. If logdir was provided at __init__, will use it - else, will make a directory and assign its absolute path to logdir.
  - _init_and_validate_kwargs(kwargs)¶
    Sets and validates **kwargs, raising an exception if kwargs result in an invalid configuration, or correcting them (and possibly notifying) when possible. Also catches unused arguments.
  - _init_class_vars()¶
    Instantiates various internal attributes. Most of these are saved and loaded by default.
deeptrain.visuals module¶
- deeptrain.visuals.binary_preds_per_iteration(_labels_cache, _preds_cache, w=1, h=1)¶
  Plots binary preds vs. labels in a heatmap, separated by batches, grouped by slices. To be used with sigmoid outputs (1 unit). Both inputs are to be shaped (batches, slices, samples, *).
  - Arguments:
    - _labels_cache: list[np.ndarray]
      List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., the first array should correspond to labels of the same batch as the predictions in _preds_cache).
    - _preds_cache: list[np.ndarray]
      List of predictions cached during training/validation; see docs on _labels_cache.
    - w, h: float
      Scale figure width & height, respectively.
- deeptrain.visuals.binary_preds_distribution(_labels_cache, _preds_cache, pred_th, w=1, h=1)¶
  Plots binary preds in a scatter plot, coloring dots according to their labels, and showing pred_th as a vertical line. The positive class (1) is labeled red, the negative (0) blue; a red dot far left (close to 0) is hence a strongly misclassified positive class, and vice versa. To be used with sigmoid outputs (1 unit).
  - Arguments:
    - _labels_cache: list[np.ndarray]
      List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., the first array should correspond to labels of the same batch as the predictions in _preds_cache).
    - _preds_cache: list[np.ndarray]
      List of predictions cached during training/validation; see docs on _labels_cache.
    - pred_th: float
      Prediction threshold (e.g. 0.5), plotted as a vertical line.
    - w, h: float
      Scale figure width & height, respectively.
- deeptrain.visuals.infer_train_hist(model, input_data, layer=None, keep_borders=True, bins=100, xlims=None, fontsize=14, vline=None, w=1, h=1)¶
  Histograms of flattened layer output values in inference and train mode. Useful for comparing the effect of layers that behave differently in train vs inference mode (Dropout, BatchNormalization, etc) on model prediction (or intermediate activations).
  - Arguments:
    - model: models.Model / models.Sequential (keras / tf.keras)
      The model.
    - input_data: np.ndarray / list[np.ndarray]
      Data to feed to model to fetch outputs. List of arrays for multi-input networks.
    - layer: layers.Layer / None
      Layer whose outputs to fetch; defaults to the last layer (output).
    - keep_borders: bool
      Whether to keep the plots' bounding box.
    - bins: int
      Number of histogram bins; kwarg to plt.hist().
    - xlims: tuple[float, float] / None
      Histogram x limits. Defaults to min/max of flattened data per plot.
    - fontsize: int
      Title font size.
    - vline: float / None
      x-coordinate of vertical line to draw (e.g. predict threshold).
    - w, h: float
      Scale figure width & height, respectively.
- deeptrain.visuals.layer_hists(model, _id='*', mode='weights', input_data=None, labels=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)¶
  Histogram grid of layer weights, outputs, or gradients; a usage sketch follows the Arguments below.
  - Arguments:
    - model: models.Model / models.Sequential (keras / tf.keras)
      The model.
    - _id: str / int / list[str/int]
      - int -> idx; str -> name
      - idx: int. Index of layer to fetch, via model.layers[idx].
      - name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
      - list[str/int] -> treat each str element as name, int as idx. Ex: ['gru', 2] gets (e.g.) weights of the first layer with name substring 'gru', then of the layer w/ idx 2.
      - '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with an 'output' attribute.
    - mode: str in ('weights', 'outputs', 'gradients:weights', 'gradients:outputs')
      Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
    - input_data: np.ndarray / list[np.ndarray] / None
      Data to feed to model to fetch outputs / gradients. List of arrays for multi-input networks. Ignored for mode='weights'.
    - labels: np.ndarray / list[np.ndarray]
      Labels to feed to model to fetch gradients. List of arrays for multi-output networks. Ignored for mode in ('weights', 'outputs').
    - omit_names: str / list[str] / tuple[str]
      Names of weights to omit for _id specifying layer names. E.g. for Dense, omit_names='bias' will fetch only kernel weights. Ignored for mode != 'weights'.
    - share_xy: tuple[bool, bool] / tuple[str, str]
      Whether to share x or y limits in the histogram grid, respectively; kwarg to plt.subplots(). Can be 'col' or 'row' for sharing along rows or columns, respectively.
    - configs: dict / None
      kwargs to customize various plot schemes:
      - 'plot': passed partly to ax.hist() in see_rnn.hist_clipped(); include peaks_to_clip to adjust ylims with a number of peaks disregarded. See help(see_rnn.hist_clipped). ax = subplots axis
      - 'subplot': passed to plt.subplots()
      - 'title': passed to fig.suptitle(); fig = subplots figure
      - 'tight': passed to fig.subplots_adjust()
      - 'annot': passed to ax.annotate()
      - 'save': passed to fig.savefig() if savepath is not None.
    - kw: dict / kwargs
      kwargs passed to see_rnn.features_hist.
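  For example (a sketch; assumes model is a built keras model with a layer whose name contains 'dense'):
  >>> from deeptrain.visuals import layer_hists
  >>> # histogram grid of kernel weights of all layers with 'dense' in
  >>> # name; biases omitted per the omit_names default
  >>> layer_hists(model, _id='dense', mode='weights')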
- deeptrain.visuals.viz_roc_auc(y_true, y_pred)¶
  Plots the Receiver Operating Characteristic curve.
- deeptrain.visuals.get_history_fig(self, plot_configs=None, w=1, h=1)¶
  Plots train / validation history according to plot_configs.
  - Arguments:
    - plot_configs: dict / None
      See _DEFAULT_PLOT_CFG. If None, defaults to TrainGenerator.plot_configs (which itself defaults to _PLOT_CFG in configs.py).
    - w, h: float
      Scale figure width & height, respectively.
  plot_configs is structured as follows:
  >>> {'fig_kw': fig_kw,
  ...  '0': {reserved_name: value,
  ...        plt_kw: value},
  ...  '1': {reserved_name: value,
  ...        plt_kw: value},
  ...  ...}
  - fig_kw: dict, passed to plt.subplots(**fig_kw)
  - reserved_name: str, one of ('metrics', 'x_ticks', 'vhlines', 'mark_best_cfg', 'ylims', 'legend_kw'). Used to configure supported custom plot behavior (see "Builtin plot customs" below).
  - plt_kw: str, name of kwarg to pass directly to plt.plot().
  - value: depends on key; see the default plot_configs in _DEFAULT_PLOT_CFG and misc._make_plot_configs_from_metrics().
  Only the 'metrics' and 'x_ticks' keys are required for each dict - others have default values.
  Builtin plot customs: (reserved_name)
  - 'metrics' (required): names of metrics to plot from histories, as {'train': train_metrics, 'val': val_metrics} (at least one metric name required, for only one of train/val - need to have "something" to plot).
  - 'x_ticks' (required): x-coordinates of respective metrics, of same len().
  - 'vhlines': dict['v' / 'h': float]. Vertical/horizontal lines; e.g. {'v': 10} will draw a vertical line at x = 10, and {'h': .5} at y = .5.
  - 'mark_best_cfg': {'train': metric_name} or {'val': metric_name}, and (optional) {'max_is_best': bool} pairs. Will mark the plot to indicate a metric optimum (max (if 'max_is_best', the default) or min).
  - 'ylims': y-limits of plot panes.
  - 'legend_kw': passed to plt.legend(); if None, no legend is drawn.
  Defaults handling:
  Keys and subkeys, where absent, will be filled from configs returned by misc._make_plot_configs_from_metrics().
  - If plot pane '0' is lacking entirely, it'll be copied from the defaults.
  - If subkey 'color' in the dict with key '0' is missing, will fill from defaults['0']['color'].
  Further info:
  - Every key's iterable value (list, etc) must be of the same len as the number of metrics in 'metrics'; this is ensured within cfg_fn.
  - Metrics are plotted in order of insertion (at both dict and list level), so later metrics will carry over to additional plot panes if the number of metrics exceeds plot_first_pane_max_vals; see cfg_fn.
  - A convenient option is to change _PLOT_CFG in configs.py and pass plot_configs=None to TrainGenerator.__init__; this will internally call cfg_fn, which validates some configs and tries to fill what's missing.
  - Above, cfg_fn == misc._make_plot_configs_from_metrics()
- deeptrain.visuals._plot_metrics(x_ticks, metrics, plot_kw, mark_best_idx=None, max_is_best=True, axis=None, vhlines=None, ylims=(0, 2), legend_kw=None, key_metric='loss', metric_name_to_alias_fn=None)¶
  Plots metrics according to inputs passed by get_history_fig().
Module contents¶
- deeptrain.scalefig(fig)¶
  Used internally to scale figures according to the env var 'SCALEFIG'.
  os.environ['SCALEFIG'] can be an int, float, tuple, list, or bracketless tuple, but must be a string: '1', '1.1', '(1, 1.1)', '1,1.1'.
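  For example:
  >>> import os
  >>> import matplotlib.pyplot as plt
  >>> from deeptrain import scalefig
  >>> os.environ['SCALEFIG'] = '(0.7, 1.2)'  # width x0.7, height x1.2
  >>> fig, ax = plt.subplots()
  >>> scalefig(fig)  # applies the env-var scaling to the figure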
- deeptrain.set_seeds(seeds=None, reset_graph=False, verbose=1)¶
  Sets random seeds and maybe clears the keras session and resets the TensorFlow default graph.
  NOTE: after calling w/ reset_graph=True, best practice is to re-instantiate the model, else some internal operations may fail due to elements from different graphs interacting (of pre-reset model and post-reset tensors).