deeptrain package
Subpackages
- deeptrain.backend package
- deeptrain.util package
  - Subpackages
  - Submodules
    - deeptrain.util.algorithms module
    - deeptrain.util.configs module
    - deeptrain.util._default_configs module
    - deeptrain.util.data_loaders module
    - deeptrain.util.experimental module
    - deeptrain.util.logging module
    - deeptrain.util.misc module
    - deeptrain.util.preprocessors module
    - deeptrain.util.saving module
    - deeptrain.util.searching module
    - deeptrain.util.training module
    - deeptrain.util._traingen_utils module
  - Module contents
Submodules
deeptrain.callbacks module
-
deeptrain.callbacks.binary_preds_per_iteration_cb(self)
Suited for binary classification sigmoid outputs. See binary_preds_per_iteration().
-
deeptrain.callbacks.binary_preds_distribution_cb(self)
Suited for binary classification sigmoid outputs. See binary_preds_distribution().
-
deeptrain.callbacks.infer_train_hist_cb(self)
Suited for binary classification sigmoid outputs. See infer_train_hist().
-
deeptrain.callbacks.make_layer_hists_cb(_id='*', mode='weights', x=None, y=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)
Layer histograms grid callback. See layer_hists().
-
class deeptrain.callbacks.TraingenCallback
Bases: object
Required base class for callbacks objects used by TrainGenerator. Enables using TrainGenerator attributes by assigning it to self via init_with_traingen. Methods are called by TrainGenerator at several stages: train, validation, save, load, and __init__. stage is an optional argument to every method (except init_with_traingen) to allow finer-grained control, particularly for ('val_end', 'train:epoch').
Methods not implemented by the inheriting class will be skipped by catching NotImplementedError. __init__ sets a private attribute _counters, which can be used along with a “freq” to call methods every Nth callback instead of every callback. Example: RandomSeedSetter.
-
init_with_traingen(traingen=None)
Called by TrainGenerator.__init__(), passing in self (TrainGenerator instance).
-
on_train_iter_end(stage=None)
Called by _on_iter_end, within TrainGenerator._train_postiter_processing(), with stage='train:iter'.
-
on_train_batch_end(stage=None)
Called by _on_batch_end, within TrainGenerator._train_postiter_processing(), with stage='train:batch'.
-
on_train_epoch_end(stage=None)
Called by _on_epoch_end, within TrainGenerator._train_postiter_processing(), with stage='train:epoch'.
-
on_val_iter_end(stage=None)
Called by _on_iter_end, within TrainGenerator._val_postiter_processing(), with stage='val:iter'.
-
on_val_batch_end(stage=None)
Called by _on_batch_end, within TrainGenerator._val_postiter_processing(), with stage='val:batch'.
-
on_val_epoch_end(stage=None)
Called by _on_epoch_end, within TrainGenerator._val_postiter_processing(), with stage='val:epoch'.
-
on_val_end(stage=None)
Called by TrainGenerator._on_val_end(), with:
  - stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
  - stage='val_end' otherwise
-
on_save(stage=None)
Called by TrainGenerator.save() with stage='save'.
-
on_load(stage=None)
Called by TrainGenerator.load() with stage='load'.
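Skipping unimplemented hooks by catching NotImplementedError can be sketched with a tiny dispatcher. This is a hypothetical illustration, not deeptrain's code: MiniCallback, EpochCounter, and dispatch are made-up names, and in deeptrain the dispatch happens inside TrainGenerator, whose exact logic may differ.

```python
class MiniCallback:
    """Hypothetical stand-in for a TraingenCallback-style base class."""

    def init_with_traingen(self, traingen=None):
        self.tg = traingen  # gives hooks access to trainer attributes

    def on_train_epoch_end(self, stage=None):
        raise NotImplementedError

    def on_val_end(self, stage=None):
        raise NotImplementedError


class EpochCounter(MiniCallback):
    """Implements a single hook; the other hooks stay unimplemented."""

    def __init__(self):
        self.calls = []

    def on_train_epoch_end(self, stage=None):
        self.calls.append(stage)


def dispatch(callbacks, method_name, stage):
    """Call `method_name` on every callback, skipping unimplemented hooks."""
    for cb in callbacks:
        try:
            getattr(cb, method_name)(stage=stage)
        except NotImplementedError:
            pass  # the subclass did not implement this hook


cb = EpochCounter()
cb.init_with_traingen(traingen=None)
dispatch([cb], 'on_train_epoch_end', stage='train:epoch')
dispatch([cb], 'on_val_end', stage='val_end')  # skipped silently
print(cb.calls)  # ['train:epoch']
```

One consequence of this pattern: a NotImplementedError raised from inside an implemented hook would also be swallowed, so hooks should not raise it for other purposes.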
-
class deeptrain.callbacks.RandomSeedSetter(seeds=None, freq={'train:epoch': 2})
Bases: deeptrain.callbacks.TraingenCallback
Periodically sets random, numpy, TensorFlow graph-level, and TensorFlow global seeds.
- Arguments:
  - seeds: dict / None
    Dict of initial seeds; if None, will default all to 0. E.g.: {'random': 0, 'numpy': 0, 'tf-graph': 1, 'tf-global': 2}.
  - freq: dict
    When to set / reset the seeds. E.g. {'train:epoch': 2} (default) will set every 2 train epochs.
By default, seeds are incremented by 1 to avoid repeating random sequences with each setting. Methods can be overridden to increment by a different amount or not at all, and to clear the keras session & reset the TF default graph, which doesn’t happen by default.
-
seeds
-
set_seeds(increment=0, seeds=None, reset_graph=False, verbose=1)
Sets seeds. increment will add to each of self.seeds and update it. If seeds is not None, will override increment.
-
classmethod _set_seeds(seeds=None, reset_graph=False, verbose=1)
See set_seeds(); _set_seeds can be used without instantiating the class, as is done in deeptrain.set_seeds().
-
classmethod reset_graph(verbose=1)
Clears keras session, and resets TensorFlow default graph.
NOTE: after calling, best practice is to re-instantiate the model, else some internal operations may fail due to elements from different graphs interacting (from docs: “using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior”).
-
call_on_freq()
-
on_train_iter_end(stage)
-
on_train_batch_end(stage)
-
on_train_epoch_end(stage)
-
on_val_iter_end(stage)
-
on_val_batch_end(stage)
-
on_val_epoch_end(stage)
-
on_val_end(stage)
-
on_save(stage)
-
on_load(stage)
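The freq-and-increment behavior described for RandomSeedSetter can be sketched with the standard random module alone. This is a hypothetical illustration, not deeptrain's code: SeedScheduler, its single integer seed, and the "increment, then reseed" order are assumptions, and the real class also seeds numpy and TensorFlow.

```python
import random


class SeedScheduler:
    """Hypothetical sketch: reseed `random` every `freq` train epochs."""

    def __init__(self, seed=0, freq=2):
        self.seed = seed
        self.freq = freq
        self._counter = 0  # number of 'train:epoch' events seen

    def on_train_epoch_end(self):
        self._counter += 1
        if self._counter % self.freq == 0:
            self.seed += 1          # increment to avoid repeating a sequence
            random.seed(self.seed)  # assumption: increment, then reseed


sched = SeedScheduler(seed=0, freq=2)
seeds_used = []
for _ in range(6):
    sched.on_train_epoch_end()
    seeds_used.append(sched.seed)
print(seeds_used)  # [0, 1, 1, 2, 2, 3]
```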
-
class deeptrain.callbacks.VizAE2D(n_images=8, save_images=False)
Bases: deeptrain.callbacks.TraingenCallback
Image AutoEncoder reconstruction visualizer.
Plots n_images original & reconstructed (model output) images in two rows, one above the other.
-
on_val_end(stage=None)
Called by TrainGenerator._on_val_end(), with:
  - stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
  - stage='val_end' otherwise
-
viz()
-
get_data()
-
class deeptrain.callbacks.TraingenLogger(savedir, configs, loadpath=None, get_data_fn=None, get_labels_fn=None, gather_fns=None, logname='datalog_', init_log_id=None)
Bases: deeptrain.callbacks.TraingenCallback
TensorBoard-like logger, gathering layer weights, outputs, and gradients over specified periods.
- Arguments:
  - savedir: str
    Path to directory where to save data and class instance state.
  - configs: dict
    Data mode-layer/weight name pairs for logging. Ex:
    >>> configs = {
    ...     'weights': ['dense_2', 'conv2d_1/kernel:0'],
    ...     'outputs': 'lstm_1',
    ...     'gradients': ('conv2d_1', 'dense_2/bias:0'),
    ...     'gradients-kw': dict(mode='weights', learning_phase=0),
    ... }
    With mode='weights' for ('weights', 'gradients'), complete weight names may be specified; else, all layer weights will be included. If layer name is a substring, will return earliest match. See see_rnn.inspect_gen methods.
  - loadpath: str / None
    Path to savefile of class instance state, to resume logging. If None, will attempt to fetch from savedir based on logname.
  - get_data_fn: function
    Will call to fetch input data to feed for 'gradients' and 'outputs' data. Defaults to lambda: TrainGenerator.val_datagen.get()[0].
  - get_labels_fn: function
    Will call to fetch labels to feed for 'gradients' data. Defaults to lambda: TrainGenerator.val_datagen.get()[1] if TrainGenerator.input_as_labels == False; else, to ~get()[0].
  - logname: str
    Base name of savefiles, prefixing save _id.
  - init_log_id: int / None
    Initial log _id; will increment by 1 with each call to log() (unless _id was passed to log()). Defaults to -1.
-
init_with_traingen(traingen)
Instantiates configs, getter functions, & others; requires a TrainGenerator instance.
-
log(_id=None)
Gathers data according to configs and gather_fns.
-
save(_id=None, clear=False, verbose=1)
Saves data per _loggables, but not other class instance attributes.
- Arguments:
  - _id: int / str[int] / None
    Appended to logname to make savepath in savedir.
  - clear: bool
    Whether to empty data attributes after saving.
  - verbose: bool / int[bool]
    Whether to print a save message with savepath.
-
load(verbose=1)
Loads from an .h5 file according to loadpath, which can include data and other class instance attributes. If loadpath is None, will attempt to fetch from savedir based on logname. Pass verbose=1 / True to print a load message with loadpath.
-
clear(verbose=1)
Sets attributes in _loggables to {}.
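TraingenLogger's savefile naming and log-id bookkeeping, as described for logname, init_log_id, and save(), can be sketched as below. This is a hypothetical helper, not part of deeptrain: LogIdTracker is an illustrative name, and both the '.h5' extension and leaving the counter untouched when an explicit _id is passed are assumptions.

```python
import os


class LogIdTracker:
    """Hypothetical sketch of log-id bookkeeping; not TraingenLogger itself."""

    def __init__(self, savedir, logname='datalog_', init_log_id=None):
        self.savedir = savedir
        self.logname = logname
        self.log_id = -1 if init_log_id is None else init_log_id

    def next_id(self, _id=None):
        if _id is None:
            self.log_id += 1  # default: increment by 1 per log() call
            return self.log_id
        return _id  # assumption: an explicit _id is used as-is

    def savepath(self, _id):
        # assumption: '.h5' files, since load() reads from an .h5 file
        return os.path.join(self.savedir, f"{self.logname}{_id}.h5")


tracker = LogIdTracker('logs')
ids = [tracker.next_id() for _ in range(3)]
print(ids)                        # [0, 1, 2]
print(tracker.next_id(_id=10))    # 10
print(tracker.savepath(ids[-1]))  # logs/datalog_2.h5 on POSIX
```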
deeptrain.data_generator module
-
class deeptrain.data_generator.DataGenerator(data_path, batch_size=32, labels_path=None, preprocessor=None, preprocessor_configs=None, data_loader=None, labels_loader=None, preload_labels=None, shuffle=False, superbatch_path=None, set_nums=None, superbatch_set_nums=None, **kwargs)
Bases: object
Central interface between a directory and TrainGenerator. Handles data loading, preprocessing, shuffling, and batching. Requires only data_path to run.
- Arguments:
  - data_path: str
    Path to directory to load data from.
  - batch_size: int
    Number of samples to feed the model at once. Can differ from size of batches of loaded files; see “Flexible batch_size”.
  - labels_path: str / None
    Path to labels file. If None, will not load labels; can be used with TrainGenerator.input_as_labels = True, feeding batch as labels in TrainGenerator.get_data() (e.g. autoencoders).
  - preprocessor: None / custom object / str in (‘timeseries’,)
    Transforms batch and labels right before both are returned by get(). See _set_preprocessor().
    - str: fetches one of API-supported preprocessors.
    - None: uses GenericPreprocessor().
    - Custom object: must subclass Preprocessor().
  - preprocessor_configs: None / dict
    Kwargs to pass to preprocessor in case it’s None, str, or an uninstantiated custom object. Ignored if preprocessor is instantiated.
  - data_loader: None / function / DataLoader()
    Object for loading data from directory / file.
    - function: passed as loader to DataLoader.__init__ in _infer_and_set_info(); input signature: (self, set_num)
    - DataLoader() instance: will set data_loader directly
    - Class subclassing DataLoader() (uninstantiated): will instantiate with attrs from _infer_info() & others
    - str: name of one of loaders in util.data_loaders
    - None: defaults to one of those defined in util.data_loaders, as determined by _infer_info()
  - labels_loader: None / function / DataLoader()
    Same as data_loader, but for labels.
  - preload_labels: bool / None
    Whether to load all labels into all_labels at __init__. Defaults to True if labels_path is a file or a directory containing a single file.
  - shuffle: bool
    If True, reset_state() will shuffle set_nums_to_process; the method is called by _on_epoch_end within TrainGenerator._train_postiter_processing() and TrainGenerator._val_postiter_processing() (via on_epoch_end()).
  - superbatch_path: str / None
    Path to file or directory from which to load superbatch (preload_superbatch()); see “Control Flow”.
  - set_nums: list[int] / None
    Used to set set_nums_original and set_nums_to_process. If None, will infer from data_path; see “Control Flow”.
  - superbatch_set_nums: list[int] / None
    set_nums to load into superbatch; see “Control Flow”.
How it works:
Data is fed to TrainGenerator() via DataGenerator(). To work, data:
- must be in one directory (or one file with all data)
- must share one file extension (.npy, .h5, etc.)
- must have file names enumerated with a common name (data1.npy, data2.npy, …)
- should have the same file batch size (# of samples, or dim 0 slices), though integer or fractional multiples (x2, x3, x1/2, x1/3, …) also work
- must have labels in one file, unless feeding input as labels (e.g. autoencoder), which doesn’t require labels files; just pass TrainGenerator(input_as_labels=True)
Flexible batch_size:
Loaded file’s batch size may differ from batch_size, so long as the former is an integer multiple or integer fraction of the latter. Ex:
- len(loaded) == 32, batch_size == 64 -> will load another file and concatenate into len(batch) == 64.
- len(loaded) == 64, batch_size == 32 -> will set first half of loaded as batch and cache loaded, then repeat for second half.
- ‘Technically’, files need not be integer (/ fraction) multiples, as the following load order works with batch_size == 32: len(loaded) == 31, len(loaded) == 1. This is not recommended, however, as it’s bound to fail if using shuffling, or if the total number of samples isn’t divisible by batch_size. Other problems may also arise.
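The flexible batch_size rules above can be sketched in plain Python, treating each loaded file as a list of samples. make_batches is a hypothetical helper written for illustration; it ignores shuffling, labels, and caching, all of which the real DataGenerator also handles.

```python
def make_batches(loaded_files, batch_size):
    """Hypothetical sketch of flexible batch_size.

    loaded_files: list of lists, one list of samples per loaded file.
    Smaller files are stacked until batch_size is reached; larger files
    must be integer multiples of batch_size and are split. A trailing
    partial batch is dropped in this sketch.
    """
    batches, buffer = [], []
    for samples in loaded_files:
        if len(samples) > batch_size:
            if len(samples) % batch_size != 0:
                raise ValueError(f"{len(samples)} is not an integer multiple "
                                 f"of batch_size={batch_size}")
            if buffer:
                raise ValueError("cannot split a file while a batch is partially stacked")
            for i in range(0, len(samples), batch_size):
                batches.append(samples[i:i + batch_size])
            continue
        buffer.extend(samples)  # stack smaller files
        if len(buffer) == batch_size:
            batches.append(buffer)
            buffer = []
        elif len(buffer) > batch_size:
            raise ValueError("stacked files overshoot batch_size")
    return batches


b1 = make_batches([list(range(32)), list(range(32, 64))], batch_size=64)
b2 = make_batches([list(range(64))], batch_size=32)
b3 = make_batches([list(range(31)), [31]], batch_size=32)
print([len(b) for b in b1], [len(b) for b in b2], [len(b) for b in b3])
# [64] [32, 32] [32]
```

The third call mirrors the "31 then 1" load order; it works here only because nothing is shuffled.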
Control Flow:
- set_num: index / Mapping key used to identify and get batch and labels via load_data() and load_labels(). Ex: for DataLoader.numpy_loader(), which expects files shaped (batches, samples, *), it’d do np.load(path)[set_num].
- set_nums_to_process: will pop set_num from this list until it’s empty; once empty, will set all_data_exhausted=True (update_state()).
- set_nums_original: will reset set_nums_to_process to this with reset_state(). It’s either set_nums passed to __init__, or is inferred in _set_set_nums() as all available set_nums in data_path file / directory.
- superbatch: dict of set_num: batch pairs loaded persistently in memory (RAM), as opposed to batch, which is overwritten. Once loaded, batch can be drawn straight from superbatch if set_num is in it (i.e. in superbatch_set_nums).
- get() returns batch and labels fed through Preprocessor.process().
- advance_batch() gets the “next” batch and labels. “Next” is determined by set_num, which is popped from set_nums_to_process[0].
- batch_exhausted: signals TrainGenerator() that a batch was consumed; this information is set via update_state() per _on_iter_end within TrainGenerator._train_postiter_processing() or TrainGenerator._val_postiter_processing().
  - If using slices (slices_per_batch is not None), then batch_exhausted is set to True only when slice_idx == slices_per_batch - 1.
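The set_num bookkeeping described in Control Flow reduces to a small state machine. MiniDataGen is a hypothetical, stripped-down stand-in (no loading, preprocessing, slices, or superbatch), meant only to show how set_nums_to_process, batch_exhausted, and all_data_exhausted interact.

```python
import random


class MiniDataGen:
    """Hypothetical sketch of the set_num control flow; not DataGenerator."""

    def __init__(self, set_nums, shuffle=False):
        self.set_nums_original = list(set_nums)
        self.shuffle = shuffle
        self.reset_state()

    def reset_state(self):
        # refill the queue from the original set_nums, optionally shuffled
        self.set_nums_to_process = list(self.set_nums_original)
        if self.shuffle:
            random.shuffle(self.set_nums_to_process)
        self.batch_exhausted = True
        self.all_data_exhausted = False

    def advance_batch(self):
        if not self.set_nums_to_process:
            raise Exception("set_nums_to_process is empty; call reset_state()")
        self.set_num = self.set_nums_to_process.pop(0)  # "next" set_num
        self.batch_exhausted = False

    def update_state(self):
        # called once the trainer has consumed the current batch
        self.batch_exhausted = True
        if self.set_nums_to_process == []:
            self.all_data_exhausted = True  # signals epoch end


dg = MiniDataGen(set_nums=[1, 2, 3])
seen = []
while not dg.all_data_exhausted:
    dg.advance_batch()
    seen.append(dg.set_num)
    dg.update_state()
print(seen)  # [1, 2, 3]
```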
__init__: Instantiation. (“+” == if certain conditions are met)
- +Infers missing configs based on args
- Validates args & kwargs, and tries to correct them, printing a “NOTE” or “WARNING” message where appropriate
- +Preloads all labels into all_labels
- Instantiates misc internal parameters to predefined values (may be overridden by TrainGenerator loading).
-
_BUILTINS = {'extensions': {'.csv', '.h5', '.npy'}, 'loaders': {'csv', 'hdf5', 'numpy', 'numpy-lz4f', 'numpy-memmap'}, 'preprocessors': (<class 'deeptrain.util.preprocessors.GenericPreprocessor'>, <class 'deeptrain.util.preprocessors.TimeseriesPreprocessor'>)}
-
get(skip_validation=False)
Returns (batch, labels) fed to Preprocessor.process().
- skip_validation: bool
  - False (default): calls _validate_batch(), which will advance_batch() if batch_exhausted, and reset_state() if all_data_exhausted.
  - True: fetch preprocessed (batch, labels) without advancing any internal states.
-
advance_batch(forced=False, is_recursive=False)
Sets next batch and labels; handles dynamic batching.
- If batch_loaded and not forced (and not is_recursive), prints a warning that batch is loaded, and returns (does nothing).
- If len(batch) != batch_size:
  - < batch_size: calls advance_batch() with is_recursive = True. With each such call, batch and labels are extended (stacked) until matching batch_size.
  - > batch_size, not integer multiple: raises Exception.
  - > batch_size, is integer multiple: makes _group_batch and _group_labels, which are used to set batch and labels.
- +If set_nums_to_process is empty, will raise Exception; it must have been reset beforehand via e.g. reset_state(). If it’s not empty, sets set_num by popping from set_nums_to_process. (+: only if _group_batch is None)
- Sets or extends batch via _get_next_batch() (by loading, or from _group_batch or superbatch).
- +Sets or extends labels via _get_next_labels() (by loading, or from _group_labels, or all_labels). (+: only if labels_path is a path (and not None))
- Sets set_name, used by TrainGenerator() to print iteration messages.
- Sets batch_loaded = True, batch_exhausted = False, all_data_exhausted = False, and slice_idx to None if it’s already None (else to 0).
-
_get_next_batch(set_num=None, warn=True)
Gets batch per set_num.
- set_num = None: will use self.set_num.
- warn = False: won’t print warning on superbatch not being preloaded.
- If _group_batch is not None, will get batch from _group_batch.
- If set_num is in superbatch_set_nums, will get batch as superbatch[set_num] (if superbatch exists).
- By default, gets batch via load_data().
-
_get_next_labels(set_num=None)
Gets labels per set_num.
- set_num = None: will use self.set_num.
- If _group_labels is not None, will get labels from _group_labels.
- If set_num is in all_labels, will get labels as all_labels[set_num].
- By default, gets labels via load_labels(), if labels_path is set; else, labels=[].
-
_batch_from_group_batch()
Slices _group_batch per batch_size and _group_batch_idx.
-
_labels_from_group_labels()
Slices _group_labels per batch_size and _group_batch_idx.
-
_update_group_batch_state()
Sets “group” attributes to None once a sufficient number of batches were extracted, else increments _group_batch_idx.
-
on_epoch_end()
Increments epoch, calls preprocessor.on_epoch_end(epoch), then reset_state(), and returns epoch.
-
update_state()
Calls preprocessor.update_state(), and if batch_exhausted and set_nums_to_process == [], sets all_data_exhausted = True to signal TrainGenerator() of epoch end.
-
reset_state(shuffle=None)
Calls preprocessor.reset_state(), sets batch_exhausted = True, batch_loaded = False, resets set_nums_to_process to set_nums_original, and shuffles set_nums_to_process if shuffle.
- If shuffle passed in is None, will set from self.shuffle.
- Used in TrainGenerator.reset_validation() w/ shuffle=False.
-
_validate_batch()
If all_data_exhausted, calls reset_state(). If batch_exhausted, calls advance_batch().
-
_make_group_batch_and_labels(n_batches)
Makes _group_batch and _group_labels when loaded len(batch) exceeds batch_size as its integer multiple. May shuffle.
- _group_batch = np.asarray(batch), and _group_labels = np.asarray(labels); each’s len() > batch_size.
- Shuffles if:
  - shuffle_group_samples: shuffles all samples (dim0 slices)
  - shuffle_group_batches: groups dim0 slices by batch_size, then shuffles the groupings. Ex:
    >>> batch_size = 32
    >>> batch.shape
    (128, 100)
    >>> batch = batch.reshape(-1, batch_size, 100)  # (4, 32, 100)
    >>> np.random.shuffle(batch)  # permutes the 4 groups: 24 (4!) orderings
    >>> batch = batch.reshape(-1, 100)  # back to (128, 100)
Sets _group_batch_idx = 0, and calls _update_group_batch_state().
Doesn’t affect labels if labels_path is falsy (e.g. None).
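Slicing an oversized loaded batch into batch_size groups, as _batch_from_group_batch() and _update_group_batch_state() describe, might look like the following. GroupBatcher and next_batch are hypothetical names; here the slice is taken before the index advances, which may not match deeptrain's exact ordering.

```python
class GroupBatcher:
    """Hypothetical sketch: serve an oversized batch in batch_size slices."""

    def __init__(self, loaded, batch_size):
        n_batches, rem = divmod(len(loaded), batch_size)
        assert rem == 0, "len(loaded) must be an integer multiple of batch_size"
        self.batch_size = batch_size
        self.n_batches = n_batches
        self._group_batch = loaded
        self._group_batch_idx = 0

    def next_batch(self):
        start = self._group_batch_idx * self.batch_size
        batch = self._group_batch[start:start + self.batch_size]
        # once every slice was served, clear the "group" attributes
        if self._group_batch_idx == self.n_batches - 1:
            self._group_batch = None
            self._group_batch_idx = None
        else:
            self._group_batch_idx += 1
        return batch


g = GroupBatcher(list(range(96)), batch_size=32)
firsts = [g.next_batch()[0] for _ in range(3)]
print(firsts)          # [0, 32, 64]
print(g._group_batch)  # None
```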
-
batch_exhausted
Is retrieved from and set in preprocessor. Indicates that batch and labels for given set_num were consumed by TrainGenerator (if using slices, that all slices were consumed).
Ex: self.batch_exhausted = 5 will set self.preprocessor.batch_exhausted = 5, and print(self.batch_exhausted) will then print 5 (or something else if preprocessor changes it internally).
-
batch_loaded
Is retrieved from and set in preprocessor, same as batch_exhausted. Indicates that batch and labels for given set_num are loaded.
-
slices_per_batch
Is retrieved from and set in preprocessor, same as batch_exhausted.
-
slice_idx
Is retrieved from and set in preprocessor, same as batch_exhausted.
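The delegation described for batch_exhausted, batch_loaded, slices_per_batch, and slice_idx follows the standard Python property pattern. The sketch shows it for batch_exhausted only; MiniPreprocessor and MiniGenerator are hypothetical classes, not deeptrain's.

```python
class MiniPreprocessor:
    """Hypothetical preprocessor that owns the state flag."""

    def __init__(self):
        self.batch_exhausted = True


class MiniGenerator:
    """Hypothetical generator forwarding batch_exhausted to its preprocessor."""

    def __init__(self):
        self.preprocessor = MiniPreprocessor()

    @property
    def batch_exhausted(self):
        return self.preprocessor.batch_exhausted  # read from preprocessor

    @batch_exhausted.setter
    def batch_exhausted(self, value):
        self.preprocessor.batch_exhausted = value  # write to preprocessor


gen = MiniGenerator()
gen.batch_exhausted = 5
print(gen.preprocessor.batch_exhausted)  # 5
print(gen.batch_exhausted)               # 5
```

Because the generator stores no copy of the flag, any change the preprocessor makes internally is immediately visible through the generator.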
-
load_data
Loads and returns batch data via data_loaders.DataLoader.load_fn(). Used by _get_next_batch() and preload_superbatch().
-
load_labels
Loads and returns labels data via data_loaders.DataLoader.load_fn(). Used by _get_next_labels() and preload_all_labels().
-
_infer_info(path)
Infers unspecified essential attributes from directory and contained files info:
- Checks that the data directory (path) isn’t empty (files whose names start with '.' aren’t counted)
- Retrieves data filepaths per path and gets data extension (the most frequent extension in the directory, excluding the “other path” from the count if in the same directory; the “other path” is data_path if path == labels_path, and vice versa)
- Gets base_name as the longest common substring among files with the ext extension
- If path is a path to a file, then filepaths=[path].
- If path is None, returns base_name=None, ext=None, filepaths=[].
-
_infer_and_set_info(data_loader, labels_loader)
Sets data_loader and labels_loader (DataLoader()), using info obtained from _infer_info().
- If info contains only one filepath, loader will operate with _is_dataset=True.
- If preload_labels is None and labels_loader._is_dataset, will set preload_labels=True.
data_loader / labels_loader are:
- function: passed as loader to DataLoader.__init__; input signature: (self, set_num)
- DataLoader() instance: will set data_loader directly
- Class subclassing DataLoader() (uninstantiated): will instantiate with attrs from _infer_info() & others
- str: name of one of loaders in util.data_loaders
- None: defaults to one of those defined in util.data_loaders, as determined by _infer_info()
-
_set_set_nums(set_nums, superbatch_set_nums)
Sets set_nums_original, set_nums_to_process, and superbatch_set_nums.
- Fetches set_nums via DataLoader._get_set_nums()
- Sets set_nums_to_process and set_nums_original; if set_nums weren’t passed to __init__, sets them to the fetched ones.
- If set_nums were passed, validates that they’re a subset of the fetched ones (i.e. can be seen by data_loader).
- Sets superbatch_set_nums; if not passed to __init__, and == 'all', sets to the fetched ones. If passed, validates that they’re a subset of the fetched ones.
- Does not validate set_nums from labels_loader’s perspective; user is expected to supply labels for each batch with a common set_num.
-
_set_preprocessor(preprocessor, preprocessor_configs)
Sets preprocessor, based on preprocessor passed to __init__:
- If None, sets to GenericPreprocessor(), instantiated with preprocessor_configs.
- If an uninstantiated class, will validate that it subclasses Preprocessor(), then instantiate with preprocessor_configs.
- If string, will match to a supported builtin.
- Validates that the set preprocessor subclasses Preprocessor().
-
preload_superbatch()
Loads all data specified by superbatch_set_nums via load_data(), and assigns them to superbatch for each set_num.
-
preload_all_labels()
Loads all labels into all_labels using load_labels(), based on set_nums_original.
-
_init_and_validate_kwargs(kwargs)
Sets and validates kwargs passed to __init__.
- Ensures data_path is a file or a directory, and labels_path is a file, directory, or None.
- Ensures kwargs are functional (compares against names in _DEFAULT_DATAGEN_CFG).
- Sets whichever names were passed with kwargs, and defaults the rest.
-
_init_class_vars()
Instantiates various internal attributes. Most of these are saved and loaded by TrainGenerator() by default.
deeptrain.introspection module
-
deeptrain.introspection.compute_gradient_norm(self, input_data, labels, sample_weight=None, learning_phase=0, _id='*', mode='weights', norm_fn=(np.sqrt, np.square), scope='local')
Computes gradients w.r.t. layer weights or outputs per _id, and returns norm according to norm_fn and scope.
- Arguments:
  - input_data: np.ndarray / list[np.ndarray] / supported formats
    Data w.r.t. which loss is to be computed for the gradient. List of arrays for multi-input networks. “Supported formats” is any valid input to model.
  - labels: np.ndarray / list[np.ndarray] / supported formats
    Labels w.r.t. which loss is to be computed for the gradient.
  - sample_weight: np.ndarray / list[np.ndarray] / supported formats
    kwarg to model.fit(), etc., weighting individual sample losses.
  - learning_phase: bool / int[bool]
    - 1: use model in train mode
    - 0: use model in inference mode
  - _id: str / int / list[str/int]
    - int -> idx; str -> name
      - idx: int. Index of layer to fetch, via model.layers[idx].
      - name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
    - list[str/int] -> treat each str element as name, int as idx. Ex: ['gru', 2] gets (e.g.) weights of first layer with name substring ‘gru’, then of layer w/ idx 2.
    - '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with ‘output’ attribute.
  - mode: str in (‘weights’, ‘outputs’, ‘gradients:weights’, ‘gradients:outputs’)
    Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
  - norm_fn: (function, function) / function
    Norm function(s) to apply to gradients arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(x) for x in data)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise.
  - scope: str in (‘local’, ‘global’)
    Whether to apply norm_fn to each gradient array individually (‘local’), or to all gradient arrays combined (‘global’).
- Returns:
  Gradient norm(s). List of float if scope == 'local' (norms of weights), else float (outer_fn(sum(sum(inner_fn(g)) for g in grads))).
TensorFlow optimizers do gradient clipping according to the clipnorm setting by comparing the L2-norms of individual weights’ gradients against clipnorm, and rescaling those that exceed it. These L2 norms can be obtained using norm_fn=(np.sqrt, np.square) with scope == 'local' and mode='weights'. See:
- tensorflow.python.keras.optimizer_v2.optimizer_v2._clip_gradients
- keras.optimizers.clip_norm
- tensorflow.python.ops.clip_ops.clip_by_norm
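The norm_fn / scope arithmetic and the clipnorm comparison can be sketched in plain Python, with each gradient array flattened to a list of floats. grad_norms and clip_by_norm are hypothetical helpers illustrating the formulas above; they are not deeptrain or TensorFlow functions, and TensorFlow operates on tensors rather than lists.

```python
import math


def grad_norms(grads, inner_fn=lambda x: x * x, outer_fn=math.sqrt, scope='local'):
    """Sketch of outer_fn(sum(inner_fn(x))); defaults give the L2 norm."""
    if scope == 'local':
        # one norm per gradient array
        return [outer_fn(sum(inner_fn(x) for x in g)) for g in grads]
    # 'global': a single norm over all arrays combined
    return outer_fn(sum(sum(inner_fn(x) for x in g) for g in grads))


def clip_by_norm(g, clipnorm):
    """Rescale g so that its L2 norm does not exceed clipnorm."""
    norm = math.sqrt(sum(x * x for x in g))
    if norm > clipnorm:
        return [x * clipnorm / norm for x in g]
    return list(g)


grads = [[3.0, 4.0], [0.0, 12.0]]
print(grad_norms(grads))                  # [5.0, 12.0]
print(grad_norms(grads, scope='global'))  # 13.0
print(clip_by_norm(grads[0], 1.0))        # approximately [0.6, 0.8]
```

Passing inner_fn=abs and outer_fn=lambda x: x instead gives the L1 norm, mirroring norm_fn=np.abs.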
-
deeptrain.introspection.gradient_norm_over_dataset(self, val=False, learning_phase=0, mode='weights', norm_fn=(np.sqrt, np.square), stat_fn=<function median>, n_iters=None, prog_freq=10, w=1, h=1)
Aggregates gradient norms over dataset, one iteration at a time. Useful for estimating what value of gradient clipping, clipnorm, to use. Plots a histogram of gathered data when finished. Also see compute_gradient_norm().
- Arguments:
  - val: bool
    - True: gather over val_datagen batches
    - False: gather over datagen batches
  - learning_phase: bool / int[bool]
    - True: get gradients of model in train mode
    - False: get gradients of model in inference mode
  - mode: str in (‘weights’, ‘outputs’)
    Whether to get gradients with respect to layer weights or outputs.
  - norm_fn: (function, function) / function
    Norm function(s) to apply to gradients arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(g) for g in grads)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise.
  - stat_fn: function
    Aggregate function to apply on computed norms. If np.mean, will gather mean of gradients; if np.median, the median, etc. Computed as: stat_fn(outer_fn(sum(inner_fn(g) for g in grads))).
  - n_iters: int / None
    Number of expected iterations over entire dataset. Can be used to iterate over a subset of the entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
  - prog_freq: int
    How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but these are the same if not using slices). E.g. 5: ....|5....|10....|15.
  - w, h: float
    Scale figure width & height, respectively.
- Returns:
  - grad_norms: np.ndarray
    Norms of gradients for every iteration. Shape: (iters_processed, n_params), where n_params is number of gradient arrays whose norm stats were computed at each iteration.
  - batches_processed: int
    Number of batches processed.
  - iters_processed: int
    Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
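Aggregating per-iteration norms with a statistic can be sketched as follows. It is an assumption that stat_fn is applied per gradient array across iterations (one value per column of grad_norms); norms_over_dataset is a hypothetical helper, and statistics.median stands in for np.median.

```python
import math
import statistics


def norms_over_dataset(batches_of_grads, stat_fn=statistics.median):
    """Sketch: per-iteration L2 norms, then one statistic per gradient array."""
    grad_norms = []  # shape (iters_processed, n_params)
    for grads in batches_of_grads:
        grad_norms.append([math.sqrt(sum(x * x for x in g)) for g in grads])
    # aggregate each column (one gradient array) across iterations
    per_param_stat = [stat_fn(col) for col in zip(*grad_norms)]
    return grad_norms, per_param_stat


iters = [[[3.0, 4.0], [1.0, 0.0]],   # iteration 1: two gradient arrays
         [[6.0, 8.0], [0.0, 2.0]],   # iteration 2
         [[0.0, 1.0], [0.0, 3.0]]]   # iteration 3
norms, stats = norms_over_dataset(iters)
print(norms)  # [[5.0, 1.0], [10.0, 2.0], [1.0, 3.0]]
print(stats)  # [5.0, 2.0]
```

A per-array median like this is one plausible starting point for picking clipnorm, since it is less sensitive to occasional gradient spikes than the mean.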
-
deeptrain.introspection.gradient_sum_over_dataset(self, val=False, learning_phase=0, mode='weights', n_iters=None, prog_freq=10, plot_kw={})
Computes cumulative sum of gradients over dataset, one iteration at a time, preserving full array shapes. Useful for computing mean of gradients over dataset, or other aggregate metrics.
- Arguments:
  - val: bool
    - True: gather over val_datagen batches
    - False: gather over datagen batches
  - learning_phase: bool / int[bool]
    - True: get gradients of model in train mode
    - False: get gradients of model in inference mode
  - mode: str in (‘weights’, ‘outputs’)
    Whether to get gradients with respect to layer weights or outputs.
  - n_iters: int / None
    Number of expected iterations over entire dataset. Can be used to iterate over a subset of the entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
  - prog_freq: int
    How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but these are the same if not using slices). E.g. 5: ....|5....|10....|15.
  - plot_kw: dict
    Kwargs to pass to see_rnn.features_hist; defaults to {'share_xy': False, 'center_zero': True}.
- Returns:
  - grad_sum: dict[str: np.ndarray]
    Gradient arrays summed over dataset. Structure: {name: array, name: array, ...}, where name is name of weight array or layer output.
  - batches_processed: int
    Number of batches processed.
  - iters_processed: int
    Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
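Summing gradients elementwise across iterations, then dividing by the iteration count to get a mean, can be sketched with plain lists standing in for arrays. sum_gradients and the weight name 'dense/kernel:0' are hypothetical.

```python
def sum_gradients(iter_grads):
    """Sketch: elementwise running sum of named gradients across iterations.

    iter_grads: iterable of dicts {name: list of floats}, one per iteration.
    """
    grad_sum, iters = {}, 0
    for grads in iter_grads:
        for name, g in grads.items():
            if name not in grad_sum:
                grad_sum[name] = [0.0] * len(g)  # keeps the array's shape
            grad_sum[name] = [s + x for s, x in zip(grad_sum[name], g)]
        iters += 1
    return grad_sum, iters


grad_sum, iters = sum_gradients([
    {'dense/kernel:0': [1.0, -2.0]},
    {'dense/kernel:0': [3.0, 4.0]},
])
mean = {k: [v / iters for v in arr] for k, arr in grad_sum.items()}
print(mean)  # {'dense/kernel:0': [2.0, 1.0]}
```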
-
deeptrain.introspection._gather_over_dataset(self, gather_fn, val=False, n_iters=None, prog_freq=10)
Iterates over DataGenerator, applying gather_fn to every batch (or slice). Stops after n_iters, or when DataGenerator.all_data_exhausted if n_iters is None. Useful for monitoring quantities over the course of training or inference.
gather_fn recursively updates data; as such, it can be used to append to a list, update a dictionary, operate on an array, etc. Review source code for exact logic.
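The gather loop, including the prog_freq progress format described for gradient_norm_over_dataset(), can be sketched over any iterable of batches. gather_over_dataset is a hypothetical stand-in: it returns the progress string instead of printing it, and stops when the iterable runs out rather than on DataGenerator.all_data_exhausted.

```python
def gather_over_dataset(batches, gather_fn, n_iters=None, prog_freq=5):
    """Sketch: apply gather_fn to each batch until n_iters or data runs out."""
    data, progress = [], []
    for i, batch in enumerate(batches, start=1):
        gather_fn(batch, data)  # updates `data` in place
        progress.append(f'|{i}' if i % prog_freq == 0 else '.')
        if n_iters is not None and i >= n_iters:
            break
    return data, ''.join(progress)


data, prog = gather_over_dataset(range(1, 16), lambda b, d: d.append(b * b))
print(prog)      # ....|5....|10....|15
print(data[:3])  # [1, 4, 9]
```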
-
deeptrain.introspection._make_gradients_fn(model, learning_phase, mode, return_names=False)
Makes a reusable gradient-getter function, separately for TF Eager & Graph execution. The Eager variant is only pseudo-reusable, as gradient tensors are still fetched anew with every call; the Graph variant should be significantly faster.
-
deeptrain.introspection.print_dead_weights(model, dead_threshold=1e-07, notify_above_frac=0.001, notify_detected_only=False)
Print names of dead weights and their proportions. Useful for debugging vanishing and exploding gradients, or quantifying sparsity.
- Arguments:
  - model: models.Model / models.Sequential (keras / tf.keras)
    The model.
  - dead_threshold: float
    Threshold below which to count the weight as “dead”, in absolute value.
  - notify_above_frac: float
    Print only if fraction of weights counted “dead” exceeds this (e.g. if there are 11 absolute values < dead_threshold out of 1000).
  - notify_detected_only: bool
    - True: print text only if dead weights are discovered
    - False: print a “not found given thresholds” message when appropriate
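Counting “dead” weights per array and applying notify_above_frac can be sketched as below, with plain lists in place of weight arrays. dead_weight_report and the weight names are hypothetical; the data reproduces the 11-out-of-1000 example from notify_above_frac.

```python
def dead_weight_report(weights, dead_threshold=1e-7, notify_above_frac=1e-3):
    """Sketch: fraction of |w| < dead_threshold per named weight array.

    weights: dict {name: list of floats}. Returns names whose dead fraction
    exceeds notify_above_frac, mapped to that fraction.
    """
    report = {}
    for name, w in weights.items():
        frac = sum(abs(x) < dead_threshold for x in w) / len(w)
        if frac > notify_above_frac:
            report[name] = frac
    return report


w = {'dense/kernel:0': [0.0] * 11 + [0.5] * 989,  # 11 of 1000 are "dead"
     'dense/bias:0': [0.1] * 1000}
print(dead_weight_report(w))  # {'dense/kernel:0': 0.011}
```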
-
deeptrain.introspection.print_nan_weights(model, notify_detected_only=False)
Print names of NaN/Inf weights and their proportions. Useful for debugging exploding or buggy gradients.
- Arguments:
  - model: models.Model / models.Sequential (keras / tf.keras)
    The model.
  - notify_detected_only: bool
    - True: print text only if NaN/Inf weights are discovered
    - False: print a “none found” message if no NaNs were found
-
deeptrain.introspection.print_large_weights(model, large_threshold=3, notify_above_frac=0.001, notify_detected_only=False)
Print names of weights in excess of set absolute value, and their proportions; excludes Inf. Useful for debugging exploding or buggy gradients.
- Arguments:
  - model: models.Model / models.Sequential (keras / tf.keras)
    The model.
  - large_threshold: float
    Threshold above which to count the weight’s absolute value as “large”.
  - notify_above_frac: float
    Print only if fraction of weights counted “large” exceeds this (e.g. if there are 11 absolute values > large_threshold out of 1000).
  - notify_detected_only: bool
    - True: print text only if large weights are discovered
    - False: print a “none found” message if no large weights were found
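Separating NaN/Inf values from finite “large” values, as print_nan_weights() and print_large_weights() describe, can be sketched with the math module. count_weights is a hypothetical helper for illustration.

```python
import math


def count_weights(w, large_threshold=3):
    """Sketch: count NaN/Inf values and finite |w| > large_threshold values."""
    n_nan_inf = sum(not math.isfinite(x) for x in w)
    # Inf is excluded from "large", matching print_large_weights' description
    n_large = sum(math.isfinite(x) and abs(x) > large_threshold for x in w)
    return n_nan_inf, n_large


w = [0.5, -4.0, float('nan'), float('inf'), 3.5]
print(count_weights(w))  # (2, 2)
```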
-
deeptrain.introspection._sample_weight_built(model)
In Graph execution, model._feed_sample_weights isn’t built unless model is compiled with sample_weight_mode set, or train_on_batch or test_on_batch is called w/ sample_weight passed.
-
deeptrain.introspection.interrupt_status(self) -> (bool, bool)
Prints whether TrainGenerator was interrupted (e.g. KeyboardInterrupt, or via exception) during train() and validate(). Returns bools (True for interrupted, else False) for each, as (train, val).
Not foolproof; user can set flags manually or via callbacks. For further assurance, check temp_history, val_temp_history, and cache attributes (e.g. _preds_cache) which are cleared at end of validate() by default; this method checks only flags: _train_loop_done, train_postiter_processed, _val_loop_done, _val_postiter_processed.
-
deeptrain.introspection.info(self)
Prints various useful TrainGenerator & DataGenerator attributes, and interrupt status.
deeptrain.metrics module
-
deeptrain.metrics.f1_score(y_true, y_pred, pred_threshold=0.5, beta=1)
-
deeptrain.metrics.mean_squared_error(y_true, y_pred, sample_weight=1)
-
deeptrain.metrics.mean_absolute_error(y_true, y_pred, sample_weight=1)
-
deeptrain.metrics.roc_auc_score(y_true, y_pred)
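The metrics above carry no descriptions. As a hedged illustration of what pred_threshold and beta typically mean, here is the standard thresholded F-beta score; f_beta_score is hypothetical, and whether deeptrain.metrics.f1_score thresholds with > or >=, or handles edge cases the same way, is not confirmed here.

```python
def f_beta_score(y_true, y_pred, pred_threshold=0.5, beta=1):
    """Sketch of a thresholded F-beta score for binary labels (plain lists)."""
    # assumption: predictions strictly above the threshold count as positive
    preds = [1 if p > pred_threshold else 0 for p in y_pred]
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, preds))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, preds))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, preds))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


print(f_beta_score([1, 0, 1, 1], [0.9, 0.6, 0.4, 0.8]))  # about 0.667
```

With beta=1 this is the usual F1 (harmonic mean of precision and recall); beta > 1 weights recall more heavily.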
deeptrain.preprocessing module¶
-
deeptrain.preprocessing.data_to_hdf5(savepath, batch_size, loaddir=None, data=None, shuffle=False, compression='lzf', dtype=None, load_fn=None, oversample_remainder=True, batches_dim0=False, overwrite=None, verbose=1)¶ Convert data to hdf5-group (.h5) format, in
batch_size sample sets.
- Arguments:
- savepath: str
- Absolute path to where to save file.
- batch_size: int
- Number of samples (dim0 slices) to save per file.
- loaddir: str
- Absolute path to directory from which to load data.
- data: np.ndarray / list[np.ndarray]
- Shape:
(samples, *) or (batches, samples, *) (the latter requires batches_dim0=True). With the former, if len(data) == 320 and batch_size == 32, will make a 10-set .h5 file.
- shuffle: bool
- Whether to shuffle samples (dim0 slices).
- compression: str
- Compression type to use. kwarg to
h5py.File().create_dataset().
- dtype: str / np.dtype
- Savefile dtype; kwarg to
.create_dataset(). Defaults to data’s dtype.
- load_fn: function / callable
- Used on supported paths (.npy) in
loaddir to load data.
- oversample_remainder: bool. Relevant only when passing data.
- True -> randomly draw (batch_size - remainder) samples to fill incomplete batch.
- False -> drop remainder samples.
- batches_dim0: bool
- Assume shapes - True:
(batches, samples, *); False: (samples, *).
- overwrite: bool / None
If savepath file exists,
- True -> replace it
- False -> don’t replace it
- None -> ask confirmation via user input
- verbose: bool
- Whether to print preprocessing progress.
Notes:
- If supplying loaddir instead of data, will iteratively load files with supported format (.npy). len() of each loaded file must divide batch_size evenly, i.e. batch_size must be an integer multiple (>= 1) of len(). So with batch_size == 32, len() == 16 works, but len() == 48 or len() == 24 doesn’t.
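The loaddir length rule above amounts to a divisibility check. A minimal sketch, where files_per_set is a hypothetical helper (not part of deeptrain) that validates the rule and reports how many consecutive loaded files would make up one batch_size set, assuming files are concatenated in load order:

```python
def files_per_set(file_len: int, batch_size: int) -> int:
    """Hypothetical check of the rule above: batch_size must be an
    integer multiple (>= 1) of each loaded file's len()."""
    if file_len <= 0 or file_len > batch_size or batch_size % file_len:
        raise ValueError(f"len() == {file_len} is incompatible "
                         f"with batch_size == {batch_size}")
    # assumed: this many consecutive loaded files fill one set
    return batch_size // file_len


print(files_per_set(16, 32))  # 2
print(files_per_set(32, 32))  # 1
for bad_len in (48, 24):
    try:
        files_per_set(bad_len, 32)
    except ValueError as e:
        print(e)
```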
-
deeptrain.preprocessing.numpy_to_lz4f(data, savepath=None, level=9, overwrite=None)¶ Do lz4-framed compression on
data. (Install compressor via !pip install py-lz4framed)
- Arguments:
- data: np.ndarray
- Data to compress.
- savepath: str
- Path to where to save file.
- level: int
- 1 to 9; higher = greater compression
- overwrite: bool
If savepath file exists,
- True -> replace it
- False -> don’t replace it
- None -> ask confirmation via user input
- Returns:
- np.ndarray - compressed array.
Example:
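>>> import numpy as np
>>> import lz4framed as lz4f  # installed via py-lz4framed
>>> from deeptrain.preprocessing import numpy_to_lz4f
>>> numpy_to_lz4f(savedata, savepath=path)
>>> # load & decompress
>>> bytes_npy = lz4f.decompress(np.load(path))
>>> loaddata = np.frombuffer(bytes_npy,
...                          dtype=savedata.dtype,  # must be original's
...                          ).reshape(*savedata.shape)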
-
deeptrain.preprocessing.numpy_data_to_numpy_sets(data, labels, savedir=None, batch_size=32, shuffle=True, data_basename='batch', oversample_remainder=True, overwrite=None, verbose=1)¶ Save
data in batch_size chunks, possibly shuffling samples.
- Arguments:
- data: np.ndarray
- Data to batch along with labels and save.
- labels: np.ndarray
- Labels to batch along with data and save.
- savedir: str / None
- Directory in which to save processed data. If None, won’t save.
- batch_size: int
- Number of samples (dim0 slices) per ‘set’.
- shuffle: bool
- Whether to shuffle samples (dim0 slices).
- data_basename: str
- Savefile name prefix, followed by set numbering, e.g. ‘batch__1.npy’, ‘batch__2.npy’, …
- oversample_remainder: bool
- True -> randomly draw (batch_size - remainder) samples to fill incomplete batch.
- False -> drop remainder samples.
- overwrite: bool / None
If a savefile in savedir already exists,
- True -> replace it
- False -> don’t replace it
- None -> ask confirmation via user input
- verbose: bool
- Whether to print preprocessing progress.
- Returns:
- data, labels: processed data & labels.
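The shuffle and remainder behavior described above can be sketched in NumPy. This is an illustration under stated assumptions, not deeptrain's code: make_sets is a hypothetical name, a single permutation keeps data and labels aligned, and oversampled fill-ins are drawn from the complete sets (falling back to all samples, with replacement if needed, when there is no complete set). Saving each set under names like ‘batch__1.npy’ is left out.

```python
import numpy as np


def make_sets(data, labels, batch_size=32, shuffle=True,
              oversample_remainder=True, seed=None):
    """Hypothetical sketch: split (data, labels) into batch_size sets."""
    rng = np.random.default_rng(seed)
    idxs = np.arange(len(data))
    if shuffle:
        rng.shuffle(idxs)  # one permutation keeps data & labels aligned

    n_full = len(idxs) // batch_size
    remainder = len(idxs) - n_full * batch_size
    if remainder and oversample_remainder:
        need = batch_size - remainder  # samples missing from the last set
        # assumption: draw fill-ins from the complete sets when possible
        pool = idxs[:n_full * batch_size] if n_full else idxs
        extra = rng.choice(pool, size=need, replace=need > len(pool))
        idxs = np.concatenate([idxs, extra])
    else:
        idxs = idxs[:n_full * batch_size]  # drop remainder samples

    n_sets = len(idxs) // batch_size
    data_sets = data[idxs].reshape(n_sets, batch_size, *data.shape[1:])
    label_sets = labels[idxs].reshape(n_sets, batch_size, *labels.shape[1:])
    return data_sets, label_sets


data = np.arange(70)
labels = data * 10
d, l = make_sets(data, labels, batch_size=32, seed=0)
print(d.shape, l.shape)  # (3, 32) (3, 32)
```

With 70 samples and batch_size=32, oversampling yields 3 sets (26 samples drawn to complete the third), while oversample_remainder=False drops the 6 leftover samples and yields 2.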
-
deeptrain.preprocessing.numpy2D_to_csv(data, savepath=None, batch_size=None, columns=None, sample_dim=1, overwrite=None)¶ Save 2D data as .csv.
- Arguments:
- data: np.ndarray
- Data to save, shaped
(batches, samples).
- savepath: str
- Path to save the file to.
- batch_size: int
- Number of rows per column; can differ from data’s.
- columns: list of str
- Column names for the data frame; defaults to enumerated columns (0, 1, 2, …).
- sample_dim: int
- Dimension applicable to batch_size.
- overwrite: bool
If savepath file exists,
- True -> replace it
- False -> don’t replace it
- None -> ask confirmation via user input
Example:
Suppose we have labels, 16 per batch, and 8 batches. Each batch is shaped (16,); stacked, (8, 16). Dim 1 is thus sample_dim. Also see examples/preprocessing/timeseries.py.
>>> data.shape == (8, 16)  # (num_batches, samples)
>>> numpy2D_to_csv(data, "data.csv", batch_size=32, sample_dim=1)
>>> # if it was (samples, num_batches), sample_dim would be 0.
>>> # This will make a DataFrame of 4 columns, 32 rows per column,
>>> # reshaping `data` to correctly concatenate samples from dim0 to dim1.
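The reshaping in the example above (8 batches of 16 samples becoming 4 columns of 32 rows) can be sketched in NumPy. to_csv_columns is hypothetical, and requiring the total sample count to be divisible by batch_size is an assumption this page doesn't state:

```python
import numpy as np


def to_csv_columns(data, batch_size, sample_dim=1):
    """Hypothetical sketch: lay out 2D `data` as a table with
    `batch_size` rows per column."""
    data = np.asarray(data)
    # put samples on the last axis so a row-major ravel appends
    # each batch's samples after the previous batch's
    flat = (data if sample_dim == 1 else data.T).ravel()
    if flat.size % batch_size:  # assumption: no partial columns
        raise ValueError("total number of samples must be divisible by batch_size")
    # each row of the reshape is one table column; transpose to (rows, columns)
    return flat.reshape(-1, batch_size).T


data = np.arange(8 * 16).reshape(8, 16)  # (num_batches, samples)
table = to_csv_columns(data, batch_size=32, sample_dim=1)
print(table.shape)  # (32, 4)
```

The resulting array could then be written with e.g. pandas.DataFrame(table).to_csv(savepath, index=False).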
deeptrain.train_generator module¶
-
class
deeptrain.train_generator.TrainGenerator(model, datagen, val_datagen, epochs=1, logs_dir=None, best_models_dir=None, loadpath=None, callbacks=None, fit_fn='train_on_batch', eval_fn='evaluate', key_metric='loss', key_metric_fn=None, val_metrics=None, custom_metrics=None, input_as_labels=False, max_is_best=None, val_freq={'epoch': 1}, plot_history_freq={'epoch': 1}, unique_checkpoint_freq={'epoch': 1}, temp_checkpoint_freq=None, class_weights=None, val_class_weights=None, reset_statefuls=False, iter_verbosity=1, logdir=None, optimizer_save_configs=None, optimizer_load_configs=None, plot_configs=None, model_configs=None, **kwargs)¶ Bases:
deeptrain.util._traingen_utils.TraingenUtils
The central DeepTrain class. Interfaces training, validation, checkpointing, data loading, and progress tracking.
- Arguments:
- model: models.Model / models.Sequential [keras / tf.keras]
- Compiled model to train.
- datagen: DataGenerator
- Train data generator; fetches inputs and labels, handles preprocessing, shuffling, and stateful formats, and informs TrainGenerator when a dataset is exhausted (epoch end).
- val_datagen: DataGenerator
- Validation data generator.
- epochs: int
- Number of train epochs.
- logs_dir: str / None
- Path to the directory in which to generate log directories, which include
TrainGenerator state, state report, model data, and others; see
checkpoint(). If None, will not checkpoint, but model saving is still possible via best_models_dir.
- best_models_dir: str / None
- Path to directory where to save best model. “Best” means having a new highest (max_is_best==True) or lowest (max_is_best==False) entry in key_metric_history. See _save_best_model().
- loadpath: str / None
- Path to .h5 file containing TrainGenerator state to load (postfixed '__state.h5' by default). See load().
- callbacks: dict[str: function] / TraingenCallback / None
- Functions to apply at various stages, including training, validation, saving, loading, and __init__. See TraingenCallback.
- fit_fn: str / function(x, y, sample_weight)
- Function, or name of model method, to feed data to during training; if str, will define e.g. fit_fn = getattr(model, 'fit'). If function, its name (substring) must include 'fit' or 'train' (currently both function identically).
- eval_fn: str / function(x, y, sample_weight)
- Function, or name of model method, to feed data to during validation; if str, will define e.g. eval_fn = getattr(model, 'evaluate'). If function, its name (substring) must include 'evaluate' or 'predict':
- 'evaluate': eval_fn uses data & labels to return metrics.
- 'predict': eval_fn uses data to return predictions, which are used internally to compute metrics.
- key_metric: str
- Name of metric to track for saving best model; will store in key_metric_history. See _save_best_model().
- key_metric_fn: function / None
- Custom function to compute key metric; overrides key_metric if not None.
- val_metrics: list[str] / None
Names of metrics to track during validation.
- If 'predict' is not in eval_fn.__name__, will be overridden by model metrics (model.compile(metrics=...)).
- If 'loss' is not included, will prepend it.
- If '*' is included, will insert model metrics at its position and pop '*'. Ex: ['*', 'f1_score'] -> ['loss', 'accuracy', 'f1_score'].
- custom_metrics: dict[str: function]
Name-function pairs of custom functions to use for gathering metrics. Functions must obey the (y_true, y_pred) input signature for their first two arguments. They may additionally accept sample_weight and pred_threshold, which will be detected and used automatically.
- Note: if using a custom metric in model.compile(loss=tf_fn), its name in custom_metrics must be the function’s code name, i.e. {tf_fn.__name__: fn} (where fn is e.g. a NumPy version).
- input_as_labels: bool
- Feed model input also to its output. Ex: autoencoders.
- max_is_best: bool
- Whether to consider greater key_metric as better in saving best model. See _save_best_model(). If None, defaults to False if key_metric=='loss', else True.
- val_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
- How frequently to validate. {'epoch': 1} -> every epoch; {'iter': 24} -> every 24 train iterations. Only one key-value pair supported. If None, won’t validate.
- plot_history_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
- How frequently to plot train & validation history. Only one key-value pair supported. If None, won’t plot history.
- unique_checkpoint_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
- How frequently to make checkpoints with unique savefile names, as opposed to temporary ones which are overwritten each time. Only one key-value pair supported. If None, won’t make unique checkpoints.
- temp_checkpoint_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
- How frequently to make checkpoints with the same predefined name, to be overwritten (“temporary”); serves as an intermediate checkpoint to unique ones, if needed. Only one key-value pair supported. If None, won’t make temporary checkpoints.
- class_weights: dict[int: int] / None
Integer-mapping of class labels to their “weights”; if not None, will feed sample_weight, mediated by the weights, to the train function (fit_fn).
>>> class_weights = {0: 4, 1: 1}
>>> labels == [1, 1, 0, 1]  # if
>>> sample_weight == [1, 1, 4, 1]  # then
- val_class_weights: dict[int: int] / None
- class_weights for validation function (eval_fn).
- reset_statefuls: bool
- Whether to call model.reset_states() at the end of every batch (train and val).
- iter_verbosity: int in {0, 1, 2}
- 0: print no iteration info
- 1: print name of set being fit / validated, metric names and values, and model.reset_states() being called
- 2: print a '.' at every iteration (useful if having multiple iterations per batch)
- logdir: str / None
- Directory to write logs to (see logs_dir). Use to specify an existing directory, e.g. to resume training and logging in the original folder. Overrides logs_dir.
- optimizer_save_configs: dict / None
- Dict specifying which optimizer attributes to include or exclude when saving. See save().
- optimizer_load_configs: dict / None
- Dict specifying which optimizer attributes to include or exclude when loading. See load().
- plot_configs: dict / None
- Dict specifying get_history_fig() behavior. See _DEFAULT_PLOT_CFG and _make_plot_configs_from_metrics().
- model_configs: dict / None
Dict specifying model information. Intended usage is: create model according to the dict, specifying hyperparameters, loss function, etc.
>>> from tensorflow.keras.layers import Input, Dense
>>> from tensorflow.keras.models import Model
>>> def make_model(batch_shape, units, optimizer, loss):
...     ipt = Input(batch_shape=batch_shape)
...     out = Dense(units)(ipt)
...     model = Model(ipt, out)
...     model.compile(optimizer, loss)
...     return model
...
>>> model_configs = {'batch_shape': (32, 16), 'units': 8,
...                  'optimizer': 'adam', 'loss': 'mse'}
>>> model = make_model(**model_configs)
Checkpoints will include an image report with the entire dict; the larger the portion of the model that’s created according to model_configs, the more will be documented for easy reference.
- kwargs: keyword arguments.
- See _DEFAULT_TRAINGEN_CFG. kwargs and all other arguments are subject to validation and correction by _validate_traingen_configs().
__init__: Instantiation. (“+” == if certain conditions are met)
- Pulls methods from TraingenUtils
- Validates args & kwargs, and tries to correct them, printing a “NOTE” or “WARNING” message where appropriate
- +Instantiates logging directory
- +Loads TrainGenerator, datagen, and val_datagen states
- +Loads model and optimizer weights (but not model architecture)
- +Preloads train & validation data (before a call to train() is made)
- +Applies initial callbacks
- +Logs initial state (_log_init_state())
- Captures and saves all arguments passed to __init__
- Instantiates misc internal parameters to predefined values (may be overridden by loading)
-
train()¶ The train loop.
- Fetches data from get_data
- Fits data via fit_fn
- Processes fit metrics in _train_postiter_processing
- Stores metrics in history
- Applies 'train:iter', 'train:batch', and 'train:epoch' callbacks
- Calls validate when appropriate
Interruption:
- Safe: during get_data, which can be called indefinitely without changing any attributes.
- Avoid: during _train_postiter_processing, where fit_fn is applied and weights are updated, but metrics aren’t stored and _train_postiter_processed=False, so the loop restarts without recording progress.
- Best bet is during validate(), as get_data may be too brief.
-
validate(record_progress=True, clear_cache=True, restart=False, use_callbacks=True)¶ Validation loop.
- Fetches data from get_data
- Applies function based on _eval_fn_name
- Processes and caches metrics/predictions in _val_postiter_processing
- Applies 'val:iter', 'val:batch', and 'val:epoch' callbacks
- Calls _on_val_end at end of validation to compute metrics and store them in val_history
- Applies 'val_end' and maybe ('val_end', 'train:epoch') callbacks
- If restart, calls reset_validation().
- Arguments:
- record_progress: bool
- If False, won’t update val_history, _val_iters, _batches_validated.
- clear_cache: bool
- If False, won’t call clear_cache(); useful for keeping preds & labels acquired during validation.
- restart: bool
- If True, will call reset_validation() before the validation loop to reset validation attributes; useful for starting afresh (e.g. if interrupted).
- use_callbacks: bool
- If False, won’t call apply_callbacks() or plot_history().
Interruption:
- Safe: during get_data, which can be called indefinitely without changing any attributes.
- Avoid: during _val_postiter_processing. Model remains unaffected*, but caches are updated; a restart may yield duplicate appending, which will error or yield inaccuracies. (* forward pass may consume random seed if random ops are used)
- In practice: prefer interrupting immediately after _print_iter_progress executes.
-
_train_postiter_processing(metrics)¶ Procedures done after every train iteration. Similar to
_val_postiter_processing(), except operating on train rather than val variables, and calling validate() when appropriate.
-
_val_postiter_processing(record_progress=True, use_callbacks=True, metrics=None, batch_size=None)¶ Procedures done after every validation iteration. Unless marked “always”, are conditional and may skip.
- Update temp val history (always)
- Update val_datagen state (always)
- Update val cache (preds, labels, etc)
- Update val history
- Reset statefuls
- Print progress
- Apply callbacks
Executes internal “callbacks” when appropriate: _on_iter_end, _on_batch_end, _on_epoch_end. List not exhaustive.
-
_on_val_end(record_progress, use_callbacks, clear_cache)¶ Procedures done after
validate(). Unless marked “always”, are conditional and may skip. List not exhaustive.
- Update train/val history
- Clear cache
- Plot history
- Checkpoint
- Apply callbacks
- Check model health
- Validate batch_size (always)
- Reset validation flags: _inferred_batch_size, _val_loop_done, _train_loop_done (always)
-
get_data(val=False)¶ Get train (
val=False) or validation (val=True) data from datagen or val_datagen, respectively. See DataGenerator. DataGenerator.get() returns x, labels; if input_as_labels == True, sets y = x, else y = labels. Either way, sets class_labels = labels. Generates sample_weight from class_labels.
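The x / y / sample_weight logic described above, together with the class_weights argument of TrainGenerator, can be sketched as follows. sketch_get_data is hypothetical, and returning None when class_weights is None is an assumption; with class_weights = {0: 4, 1: 1} and labels [1, 1, 0, 1], each sample receives its class's weight, giving [1, 1, 4, 1].

```python
import numpy as np


def sketch_get_data(x, labels, input_as_labels=False, class_weights=None):
    """Hypothetical sketch of the x / y / sample_weight logic described above."""
    y = x if input_as_labels else labels  # e.g. autoencoders: target is the input
    class_labels = np.asarray(labels).ravel()
    if class_weights is None:
        sample_weight = None  # assumption: no weights when class_weights is None
    else:
        # look up each sample's weight by its integer class label
        sample_weight = np.array([class_weights[int(c)] for c in class_labels])
    return x, y, sample_weight


x = np.zeros((4, 2))
_, y, sw = sketch_get_data(x, np.array([1, 1, 0, 1]), class_weights={0: 4, 1: 1})
print(sw)  # [1 1 4 1]
```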
-
clear_cache(reset_val_flags=False)¶ Call to reset cache attributes accumulated during validation; useful for “restarting” validation (before calling
validate()).
Attributes set to []: {'_preds_cache', '_labels_cache', '_sw_cache', '_class_labels_cache', '_set_name_cache', '_val_set_name_cache', '_y_true', '_val_sw'}.
-
reset_validation()¶ Used to restart validation (e.g. in case interrupted); calls
clear_cache() and DataGenerator.reset_state() (and, if reset_statefuls, model.reset_states()).
Does not reset validation counters (e.g. _val_iters).
-
_should_do(freq_config, forced=False)¶ Checks whether a counter meets a frequency as specified in
val_freq, plot_history_freq, unique_checkpoint_freq, or temp_checkpoint_freq.
“Counter” is one of _fit_iters, _batches_fit, epoch, and _times_validated. Ex: with unique_checkpoint_freq = {'batch': 5}, checkpoint() will make a unique checkpoint on every 5th batch fitted during train().
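A minimal sketch of the frequency check described above, mapping each freq_config key to its counter. should_do and COUNTER_NAMES are hypothetical stand-ins; treating forced as "always True", and never firing on a zero counter, are assumptions, not confirmed behavior of _should_do.

```python
COUNTER_NAMES = {'iter': '_fit_iters', 'batch': '_batches_fit',
                 'epoch': 'epoch', 'val': '_times_validated'}


def should_do(freq_config, counters, forced=False):
    """Hypothetical sketch of the frequency check described above."""
    if forced:  # assumption: `forced` short-circuits the check
        return True
    if freq_config is None:  # e.g. val_freq=None -> never
        return False
    (key, freq), = freq_config.items()  # only one key-value pair supported
    count = counters[COUNTER_NAMES[key]]
    return count > 0 and count % freq == 0  # assumption: skip count == 0


counters = {'_fit_iters': 40, '_batches_fit': 10,
            'epoch': 3, '_times_validated': 2}
print(should_do({'batch': 5}, counters))  # True: 10 is a multiple of 5
print(should_do({'iter': 24}, counters))  # False: 40 % 24 != 0
```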
-
_update_val_iter_cache()¶ Called by
_on_iter_end within _val_postiter_processing(); updates validation cache variables (_labels_cache, _preds_cache, _class_labels_cache, _sw_cache).
If val_datagen has a non-None slice_idx, will preserve batch-slice structure:
>>> [[y00, y01, y02], [y10, y11, y12]]  # 2 batches, 3 slices/batch
>>> [[y00, y01], [y10, y11], [y20, y21]]  # 3 batches, 2 slices/batch
-
_print_train_progress()¶ Called within
_train_postiter_processing(), by on_batch_end().
-
_print_val_progress()¶ Called within
_val_postiter_processing(), by on_batch_end().
-
_print_progress(metrics, endchar='\n')¶ Called by
_print_train_progress() and _print_val_progress().
-
_print_iter_progress(val=False)¶ Called within
train() and validate().
-
plot_history(update_fig=True, w=1, h=1)¶ Plots train & validation history (from
history and val_history).
- update_fig=True -> store latest fig in _history_fig.
- w & h scale the width & height, respectively, of the figure.
- Plots configured by plot_configs.
-
_apply_callbacks(stage)¶ Callbacks. See examples/callbacks
- Two approaches:
Class-based: inherit deeptrain.callbacks.TraingenCallback and define stage-based methods, e.g. on_train_epoch_end. Methods also take a stage argument for further control, e.g. to only call on_train_epoch_end when stage == ('val_end', 'train:epoch').
Function-based: make a dict of stage-function call pairs, e.g.:
>>> {'train:epoch': (fn1, fn2),
...  'val:batch': fn3,
...  ('val_end', 'train:epoch'): fn4}
Callback will execute if a key is in the stage passed to _apply_callbacks; e.g. (fn1, fn2) will execute on stage == ('val_end', 'train:epoch'), with key 'train:epoch', but fn4 won’t execute on stage == 'train:epoch'.
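The key-matching rule for function-based callbacks can be sketched as below, using a dict of the shape above written as valid Python (a tuple key ('val_end', 'train:epoch'), and 'val:batch' as in the stage names used by validate()). matching_callbacks is hypothetical; the assumption is that a key matches when it equals the stage, or is an element of a tuple stage.

```python
def matching_callbacks(callbacks, stage):
    """Hypothetical sketch: collect functions whose key matches `stage`."""
    matched = []
    for key, fns in callbacks.items():
        # assumption: match on equality, or membership in a tuple stage
        if key == stage or (isinstance(stage, tuple) and key in stage):
            matched.extend(fns if isinstance(fns, (tuple, list)) else (fns,))
    return matched


# stand-in callback functions
def fn1(): pass
def fn2(): pass
def fn3(): pass
def fn4(): pass


callbacks = {'train:epoch': (fn1, fn2),
             'val:batch': fn3,
             ('val_end', 'train:epoch'): fn4}

print([f.__name__ for f in matching_callbacks(callbacks, ('val_end', 'train:epoch'))])
# ['fn1', 'fn2', 'fn4']
print([f.__name__ for f in matching_callbacks(callbacks, 'train:epoch')])
# ['fn1', 'fn2']
```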
-
_init_callbacks()¶ Instantiates callback objects (must subclass
TraingenCallback), passing in TrainGenerator instance as first (and only) argument. Enables custom callbacks utilizing TrainGenerator attributes and methods.
-
check_health(dead_threshold=1e-07, dead_notify_above_frac=0.001, large_threshold=3, large_notify_above_frac=0.001, notify_detected_only=True)¶ Check whether any layer weights have ‘zeros’, NaN, or ‘large’ values; very fast / inexpensive.
- Arguments:
- dead_threshold: float
- Count values below this as zeros.
- dead_notify_above_frac: float
- If fraction of zero values exceeds this, print it and the weight’s name.
- large_threshold: float
- Count absolute values above this as large.
- large_notify_above_frac: float
- If fraction of large values exceeds this, print it and the weight’s name.
- notify_detected_only: bool
- True -> print only if dead/NaN/large values are found
- False -> also print a ‘not found’ message
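The checks described for check_health can be sketched on a plain dict of NumPy arrays. weight_health is hypothetical and returns a report instead of printing; taking absolute values for the dead check is an assumption. In eager TensorFlow 2, such a dict could be built with something like {w.name: w.numpy() for w in model.weights}.

```python
import numpy as np


def weight_health(weights, dead_threshold=1e-7, dead_notify_above_frac=1e-3,
                  large_threshold=3, large_notify_above_frac=1e-3):
    """Hypothetical sketch: flag weights with NaNs, or too many near-zero
    ('dead') or large absolute values. `weights` maps names to arrays."""
    report = {}
    for name, w in weights.items():
        a = np.abs(np.asarray(w, dtype=float))
        issues = []
        if np.isnan(a).any():
            issues.append('nan')
        # fraction of entries counted dead / large (NaN compares False to both)
        if np.mean(a < dead_threshold) > dead_notify_above_frac:
            issues.append('dead')
        if np.mean(a > large_threshold) > large_notify_above_frac:
            issues.append('large')
        if issues:
            report[name] = issues
    return report


weights = {'dense/kernel': np.zeros((10, 10)),
           'dense/bias': np.array([0.5, np.nan]),
           'out/kernel': np.full(4, 5.0),
           'ok/kernel': np.ones(100)}
print(weight_health(weights))
# {'dense/kernel': ['dead'], 'dense/bias': ['nan'], 'out/kernel': ['large']}
```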
-
destroy(confirm=False, verbose=1)¶ Class ‘destructor’. Sets own, datagen’s, and val_datagen’s attributes to
[] (which can free memory of arrays), then deletes them. Also deletes ‘model’ attribute, but this has no effect on memory allocation until it’s dereferenced globally and the TensorFlow/Keras graph is cleared (best bet is to restart the Python kernel).
-
fit_fn¶
-
eval_fn¶
-
epoch¶
-
val_epoch¶
-
_prepare_initial_data(from_load=False)¶ Preloads first batch for training and validation, and superbatch if available.
-
_init_logger()¶ Instantiate log directory for checkpointing. If
logdir was provided at __init__, will use it; else, will make a directory and assign its absolute path to logdir.
-
_init_and_validate_kwargs(kwargs)¶ Sets and validates
**kwargs, raising exception if kwargs result in an invalid configuration, or correcting them (and possibly notifying) when possible. Also catches unused arguments.
-
_init_class_vars()¶ Instantiates various internal attributes. Most of these are saved and loaded by default.
deeptrain.visuals module¶
-
deeptrain.visuals.binary_preds_per_iteration(_labels_cache, _preds_cache, w=1, h=1)¶ Plots binary preds vs. labels in a heatmap, separated by batches, grouped by slices.
To be used with sigmoid outputs (1 unit). Both inputs are to be shaped
(batches, slices, samples, *).
- Arguments:
- _labels_cache: list[np.ndarray]
- List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., first array should correspond to labels of same batch as predictions in _preds_cache).
- _preds_cache: list[np.ndarray]
- List of predictions cached during training/validation; see docs on _labels_cache.
- w, h: float
- Scale figure width & height, respectively.
-
deeptrain.visuals.binary_preds_distribution(_labels_cache, _preds_cache, pred_th, w=1, h=1)¶ Plots binary preds in a scatter plot, labeling dots according to their labels, and showing
pred_th as a vertical line. Positive class (1) is labeled red, negative (0) blue; a red dot far left (close to 0) is hence a strongly misclassified positive class, and vice versa.
To be used with sigmoid outputs (1 unit).
- Arguments:
- _labels_cache: list[np.ndarray]
- List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., first array should correspond to labels of same batch as predictions in _preds_cache).
- _preds_cache: list[np.ndarray]
- List of predictions cached during training/validation; see docs on _labels_cache.
- pred_th: float
- Predict threshold (e.g. 0.5), plotted as a vertical line.
- w, h: float
- Scale figure width & height, respectively.
-
deeptrain.visuals.infer_train_hist(model, input_data, layer=None, keep_borders=True, bins=100, xlims=None, fontsize=14, vline=None, w=1, h=1)¶ Histograms of flattened layer output values in inference and train mode. Useful for comparing the effect of layers that behave differently in train vs. inference modes (Dropout, BatchNormalization, etc) on model predictions (or intermediate activations).
- Arguments:
- model: models.Model / models.Sequential (keras / tf.keras)
- The model.
- input_data: np.ndarray / list[np.ndarray]
- Data to feed to
model to fetch outputs. List of arrays for multi-input networks.
- layer: layers.Layer / None
- Layer whose outputs to fetch; defaults to last layer (output).
- keep_borders: bool
- Whether to keep the plots’ bounding box.
- bins: int
- Number of histogram bins; kwarg to
plt.hist()
- xlims: tuple[float, float] / None
- Histogram x limits. Defaults to min/max of flattened data per plot.
- fontsize: int
- Title font size.
- vline: float / None
- x-coordinate of vertical line to draw (e.g. predict threshold).
- w, h: float
- Scale figure width & height, respectively.
-
deeptrain.visuals.layer_hists(model, _id='*', mode='weights', input_data=None, labels=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)¶ Histogram grid of layer weights, outputs, or gradients.
- Arguments:
- model: models.Model / models.Sequential (keras / tf.keras)
- The model.
- _id: str / int / list[str/int].
- int -> idx; str -> name
- idx: int. Index of layer to fetch, via model.layers[idx].
- name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
- list[str/int] -> treat each str element as name, int as idx.
Ex: ['gru', 2] gets (e.g.) weights of first layer with name substring ‘gru’, then of layer w/ idx 2.
- '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with ‘output’ attribute.
- mode: str in (‘weights’, ‘outputs’, ‘gradients:weights’, ‘gradients:outputs’)
- Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
- input_data: np.ndarray / list[np.ndarray] / None
- Data to feed to
model to fetch outputs / gradients. List of arrays for multi-input networks. Ignored for mode='weights'.
- labels: np.ndarray / list[np.ndarray]
- Labels to feed to
model to fetch gradients. List of arrays for multi-output networks. Ignored for mode in ('weights', 'outputs').
- omit_names: str / list[str] / tuple[str]
- Names of weights to omit for
_id specifying layer names. E.g. for Dense, omit_names='bias' will fetch only kernel weights. Ignored for mode != 'weights'.
- share_xy: tuple[bool, bool] / tuple[str, str]
- Whether to share x or y limits in histogram grid, respectively.
kwarg to plt.subplots(); can be 'col' or 'row' for sharing along columns or rows, respectively.
- configs: dict / None
kwargs to customize various plot schemes:
- 'plot': passed partly to ax.hist() in see_rnn.hist_clipped(); include peaks_to_clip to adjust ylims with a number of peaks disregarded. See help(see_rnn.hist_clipped). ax = subplots axis
- 'subplot': passed to plt.subplots()
- 'title': passed to fig.suptitle(); fig = subplots figure
- 'tight': passed to fig.subplots_adjust()
- 'annot': passed to ax.annotate()
- 'save': passed to fig.savefig() if savepath is not None.
- kw: dict / kwargs
- kwargs passed to
see_rnn.features_hist.
-
deeptrain.visuals.viz_roc_auc(y_true, y_pred)¶ Plots the Receiver Operating Characteristic curve.
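For reference, the points of a binary ROC curve and the area under it can be computed with plain NumPy, as sketched below. roc_curve_points and auc_trapezoid are hypothetical helpers, not deeptrain's viz_roc_auc or roc_auc_score; both classes must be present in y_true.

```python
import numpy as np


def roc_curve_points(y_true, y_pred):
    """Hypothetical sketch: (fpr, tpr) points of a binary ROC curve."""
    y_true = np.asarray(y_true).ravel().astype(int)
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    order = np.argsort(-y_pred, kind='mergesort')  # descending scores
    y_true, y_pred = y_true[order], y_pred[order]
    # last index of each distinct score = one threshold per distinct value
    idx = np.r_[np.flatnonzero(np.diff(y_pred)), y_true.size - 1]
    tps = np.cumsum(y_true)[idx]  # true positives at each threshold
    fps = (idx + 1) - tps         # false positives at each threshold
    tpr = np.r_[0.0, tps / tps[-1]]
    fpr = np.r_[0.0, fps / fps[-1]]
    return fpr, tpr


def auc_trapezoid(fpr, tpr):
    # area under the piecewise-linear ROC curve
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


fpr, tpr = roc_curve_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(auc_trapezoid(fpr, tpr))  # 0.75
```

Plotting the curve would then be e.g. plt.plot(fpr, tpr). Grouping tied scores into one threshold makes ties contribute half credit, which matches the usual AUC definition.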
-
deeptrain.visuals.get_history_fig(self, plot_configs=None, w=1, h=1)¶ Plots train / validation history according to
plot_configs.
- Arguments:
- plot_configs: dict / None
- See
_DEFAULT_PLOT_CFG. If None, defaults to TrainGenerator.plot_configs (which itself defaults to _PLOT_CFG in configs.py).
- w, h: float
- Scale figure width & height, respectively.
plot_configs is structured as follows:
>>> {'fig_kw': fig_kw,
...  '0': {reserved_name: value,
...        plt_kw: value},
...  '1': {reserved_name: value,
...        plt_kw: value},
...  ...}
- fig_kw: dict, passed to plt.subplots(**fig_kw)
- reserved_name: str, one of ('metrics', 'x_ticks', 'vhlines', 'mark_best_cfg', 'ylims', 'legend_kw'). Used to configure supported custom plot behavior (see “Builtin plot customs” below).
- plt_kw: str, name of kwarg to pass directly to plt.plot().
- value: depends on key; see default plot_configs in _DEFAULT_PLOT_CFG and misc._make_plot_configs_from_metrics().
Only 'metrics' and 'x_ticks' keys are required for each dict; others have default values.
Builtin plot customs (reserved_name):
- 'metrics' (required): names of metrics to plot from histories, as {'train': train_metrics, 'val': val_metrics} (at least one metric name is required, in either train or val, so there is something to plot).
- 'x_ticks' (required): x-coordinates of respective metrics, of same len().
- 'vhlines': dict[‘v’ / ‘h’: float]. Vertical/horizontal lines; e.g. {'v': 10} will draw a vertical line at x = 10, and {'h': .5} at y = .5.
- 'mark_best_cfg': {'train': metric_name} or {'val': metric_name} and (optional) {'max_is_best': bool} pairs. Will mark plot to indicate a metric optimum (max if 'max_is_best' (the default), else min).
- 'ylims': y-limits of plot panes.
- 'legend_kw': passed to plt.legend(); if None, no legend is drawn.
Defaults handling:
Keys and subkeys, where absent, will be filled from configs returned by misc._make_plot_configs_from_metrics().
- If plot pane '0' is lacking entirely, it’ll be copied from the defaults.
- If subkey 'color' in dict with key '0' is missing, will fill from defaults['0']['color'].
Further info:
- Every key’s iterable value (list, etc) must be of same len as number of metrics in 'metrics'; this is ensured within cfg_fn.
- Metrics are plotted in order of insertion (at both dict and list level), so later metrics will carry over to additional plot panes if number of metrics exceeds plot_first_pane_max_vals; see cfg_fn.
- A convenient option is to change _PLOT_CFG in configs.py and pass plot_configs=None to TrainGenerator.__init__; this will internally call cfg_fn, which validates some configs and tries to fill what’s missing.
- Above, cfg_fn == misc._make_plot_configs_from_metrics()
-
deeptrain.visuals._plot_metrics(x_ticks, metrics, plot_kw, mark_best_idx=None, max_is_best=True, axis=None, vhlines=None, ylims=(0, 2), legend_kw=None, key_metric='loss', metric_name_to_alias_fn=None)¶ Plots metrics according to inputs passed by
get_history_fig().
Module contents¶
-
deeptrain.scalefig(fig)¶ Used internally to scale figures according to env var ‘SCALEFIG’.
os.environ[‘SCALEFIG’] must be a string, encoding an int, float, tuple, list, or bracketless tuple: ‘1’, ‘1.1’, ‘(1, 1.1)’, ‘1,1.1’.
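The text lists the accepted SCALEFIG string formats but not how they are applied. A minimal parsing sketch, where parse_scalefig is hypothetical and the interpretation (one number scales both width and height, a pair gives width and height scales) is an assumption:

```python
import ast
import os


def parse_scalefig(value=None):
    """Hypothetical parser for the SCALEFIG string formats listed above;
    returns (width_scale, height_scale)."""
    if value is None:
        value = os.environ.get('SCALEFIG', '1')
    # '1,1.1' evaluates to the tuple (1, 1.1), same as '(1, 1.1)'
    parsed = ast.literal_eval(value)
    if isinstance(parsed, (int, float)):
        return (parsed, parsed)  # assumption: one number scales both dimensions
    w, h = parsed  # tuple or list of two numbers
    return (w, h)


for s in ('1', '1.1', '(1, 1.1)', '1,1.1', '[2, 3]'):
    print(s, '->', parse_scalefig(s))
```

ast.literal_eval only evaluates literals, so a malformed or malicious env value raises an error rather than executing code. A figure could then be resized with matplotlib's fig.set_size_inches(w0 * w, h0 * h), where (w0, h0) = fig.get_size_inches().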
-
deeptrain.set_seeds(seeds=None, reset_graph=False, verbose=1)¶ Sets random seeds and maybe clears keras session and resets TensorFlow default graph.
NOTE: after calling w/
reset_graph=True, best practice is to re-instantiate model, else some internal operations may fail due to elements from different graphs interacting (of pre-reset model and post-reset tensors).