deeptrain package

Submodules

deeptrain.callbacks module

deeptrain.callbacks.binary_preds_per_iteration_cb(self)

Suited for binary classification sigmoid outputs. See binary_preds_per_iteration().

deeptrain.callbacks.binary_preds_distribution_cb(self)

Suited for binary classification sigmoid outputs. See binary_preds_distribution().

deeptrain.callbacks.infer_train_hist_cb(self)

Suited for binary classification sigmoid outputs. See infer_train_hist().

deeptrain.callbacks.make_layer_hists_cb(_id='*', mode='weights', x=None, y=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)

Layer histograms grid callback. See layer_hists().

class deeptrain.callbacks.TraingenCallback

Bases: object

Required base class for callback objects used by TrainGenerator. Enables using TrainGenerator attributes by assigning the instance to self via init_with_traingen. Methods are called by TrainGenerator at several stages: train, validation, save, load, __init__.

stage is an optional argument to every method (except init_with_traingen) to allow finer-grained control, particularly for ('val_end', 'train:epoch').

Methods not implemented by the inheriting class will be skipped by catching NotImplementedError.

__init__ sets a private attribute _counters, which can be used along with a “freq” to call methods on every Nth callback call instead of on every call. Example: RandomSeedSetter.
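
For illustration, a minimal subclass might look as follows (a sketch based on the interface described in this section; ValLossPrinter and its printing logic are hypothetical):

>>> from deeptrain.callbacks import TraingenCallback
>>>
>>> class ValLossPrinter(TraingenCallback):
...     """Hypothetical callback: prints the latest key metric after validation."""
...     def init_with_traingen(self, traingen=None):
...         self.tg = traingen  # keep a reference to the TrainGenerator
...     def on_val_end(self, stage=None):
...         # key_metric_history is documented under TrainGenerator
...         if self.tg.key_metric_history:
...             print("val key metric:", self.tg.key_metric_history[-1])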

init_with_traingen(traingen=None)

Called by TrainGenerator.__init__(), passing in self (TrainGenerator instance).

on_train_iter_end(stage=None)

Called by _on_iter_end, within TrainGenerator._train_postiter_processing() with stage='train:iter'.

on_train_batch_end(stage=None)

Called by _on_batch_end, within TrainGenerator._train_postiter_processing(), with stage='train:batch'.

on_train_epoch_end(stage=None)

Called by _on_epoch_end, within TrainGenerator._train_postiter_processing(), with stage='train:epoch'.

on_val_iter_end(stage=None)

Called by _on_iter_end, within TrainGenerator._val_postiter_processing(), with stage='val:iter'.

on_val_batch_end(stage=None)

Called by _on_batch_end, within TrainGenerator._val_postiter_processing(), with stage='val:batch'.

on_val_epoch_end(stage=None)

Called by _on_epoch_end, within TrainGenerator._val_postiter_processing(), with stage='val:epoch'.

on_val_end(stage=None)

Called by TrainGenerator._on_val_end(), with:

  • stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
  • stage='val_end' otherwise
on_save(stage=None)

Called by TrainGenerator.save() with stage='save'.

on_load(stage=None)

Called by TrainGenerator.load() with stage='load'.

class deeptrain.callbacks.RandomSeedSetter(seeds=None, freq={'train:epoch': 2})

Bases: deeptrain.callbacks.TraingenCallback

Periodically sets random, numpy, TensorFlow graph-level, and TensorFlow global seeds.

Arguments:
seeds: dict / None
Dict of initial seeds; if None, will default all to 0. E.g.: {'random': 0, 'numpy': 0, 'tf-graph': 1, 'tf-global': 2}.
freq: dict
When to set / reset the seeds. E.g. {'train:epoch': 2} (default) will set every 2 train epochs. By default, seeds are incremented by 1 to avoid repeating random sequences with each setting.

Methods can be overridden to increment by a different amount or not at all, and to clear keras session & reset TF default graph, which doesn’t happen by default.
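
A usage sketch (the callback is passed to TrainGenerator via its callbacks argument, as with any TraingenCallback; the seed values are arbitrary):

>>> from deeptrain.callbacks import RandomSeedSetter
>>> seed_setter = RandomSeedSetter(seeds={'random': 0, 'numpy': 0,
...                                       'tf-graph': 1, 'tf-global': 2},
...                                freq={'train:epoch': 2})
>>> # pass via TrainGenerator(..., callbacks=[seed_setter]); seeds are then
>>> # set (and incremented by 1) every 2 train epochs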

seeds
set_seeds(increment=0, seeds=None, reset_graph=False, verbose=1)

Sets seeds. increment will add to each of self.seeds and update it. If seeds is not None, will override increment.

classmethod _set_seeds(seeds=None, reset_graph=False, verbose=1)

See set_seeds(); _set_seeds can be used without instantiating class, as is used in deeptrain.set_seeds().

classmethod reset_graph(verbose=1)

Clears keras session, and resets TensorFlow default graph.

NOTE: after calling, best practice is to re-instantiate model, else some internal operations may fail due to elements from different graphs interacting (from docs: “using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior”).

call_on_freq()
on_train_iter_end(stage)
on_train_batch_end(stage)
on_train_epoch_end(stage)
on_val_iter_end(stage)
on_val_batch_end(stage)
on_val_epoch_end(stage)
on_val_end(stage)
on_save(stage)
on_load(stage)
class deeptrain.callbacks.VizAE2D(n_images=8, save_images=False)

Bases: deeptrain.callbacks.TraingenCallback

Image AutoEncoder reconstruction visualizer.

Plots n_images original and reconstructed (model output) images in two rows, aligned vertically for direct comparison.

on_val_end(stage=None)

Called by TrainGenerator._on_val_end(), with:

  • stage=('val_end', 'train:epoch') if TrainGenerator.datagen.all_data_exhausted
  • stage='val_end' otherwise
viz()
get_data()
class deeptrain.callbacks.TraingenLogger(savedir, configs, loadpath=None, get_data_fn=None, get_labels_fn=None, gather_fns=None, logname='datalog_', init_log_id=None)

Bases: deeptrain.callbacks.TraingenCallback

TensorBoard-like logger, gathering layer weights, outputs, and gradients over specified periods.

Arguments:
savedir: str
Path to directory where to save data and class instance state.
configs: dict

Data mode-layer/weight name pairs for logging. Ex:

>>> configs = {
...     'weights': ['dense_2', 'conv2d_1/kernel:0'],
...     'outputs': 'lstm_1',
...     'gradients': ('conv2d_1', 'dense_2/bias:0'),
...     'gradients-kw': dict(mode='weights', learning_phase=0),
... }

For 'weights' and 'gradients' (with mode='weights'), complete weight names may be specified - else, all of the layer's weights will be included. If a layer name is a substring, the earliest match is returned. See see_rnn.inspect_gen methods.

loadpath: str / None
Path to savefile of class instance state, to resume logging. If None, will attempt to fetch from savedir based on logname.
get_data_fn: function
Will call to fetch input data to feed for 'gradients' and 'outputs' data. Defaults to lambda: TrainGenerator.val_datagen.get()[0].
get_labels_fn: function
Will call to fetch labels to feed for 'gradients' data. Defaults to lambda: TrainGenerator.val_datagen.get()[1] if TrainGenerator.input_as_labels == False - else, to ~get()[0].
logname: str
Base name of savefiles, prefixing save _id.
init_log_id: int / None
Initial log _id; will increment by 1 with each call to log() (unless _id was passed to log()). Defaults to -1.
init_with_traingen(traingen)

Instantiates configs, getter functions, & others; requires a TrainGenerator instance.

log(_id=None)

Gathers data according to configs and gather_fns.

save(_id=None, clear=False, verbose=1)

Saves data per _loggables, but not other class instance attributes.

Arguments:
_id: int / str[int] / None
Appended to logname to make savepath in savedir.
clear: bool
Whether to empty data attributes after saving.
verbose: bool / int[bool]
Whether to print a save message with savepath.
load(verbose=1)

Loads from an .h5 file according to loadpath, which can include data and other class instance attributes. If loadpath is None, will attempt to fetch from savedir based on logname.

verbose == 1/True to print load message with loadpath.

clear(verbose=1)

Sets attributes in _loggables to {}.
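
A usage sketch (savedir, layer names, and the TrainGenerator instance tg are hypothetical; as a TraingenCallback, the logger can also be driven automatically by passing it via TrainGenerator(callbacks=...)):

>>> from deeptrain.callbacks import TraingenLogger
>>> configs = {'weights': ['dense_1'],
...            'outputs': 'dense_1',
...            'gradients': ('dense_1',)}
>>> logger = TraingenLogger('logger_dir', configs)  # savedir, configs
>>> logger.init_with_traingen(tg)                   # requires a TrainGenerator instance
>>> logger.log()                                    # gather per `configs` & `gather_fns`
>>> logger.save(clear=True)                         # save gathered data, then empty it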

deeptrain.data_generator module

class deeptrain.data_generator.DataGenerator(data_path, batch_size=32, labels_path=None, preprocessor=None, preprocessor_configs=None, data_loader=None, labels_loader=None, preload_labels=None, shuffle=False, superbatch_path=None, set_nums=None, superbatch_set_nums=None, **kwargs)

Bases: object

Central interface between a directory and TrainGenerator. Handles data loading, preprocessing, shuffling, and batching. Requires only data_path to run.

Arguments:
data_path: str
Path to directory to load data from.
batch_size: int
Number of samples to feed the model at once. Can differ from the size of batches in loaded files; see “Flexible batch_size” below.
labels_path: str / None
Path to labels file. If None, will not load labels; can be used with TrainGenerator.input_as_labels = True, feeding batch as labels in TrainGenerator.get_data() (e.g. autoencoders).
preprocessor: None / custom object / str in (‘timeseries’,)

Transforms batch and labels right before both are returned by get(). See _set_preprocessor().

preprocessor_configs: None / dict
Kwargs to pass to preprocessor in case it’s None, str, or an uninstantiated custom object. Ignored if preprocessor is instantiated.
data_loader: None / function / DataLoader()

Object for loading data from directory / file.

labels_loader: None / function / DataLoader()
data_loader, but for labels.
preload_labels: bool / None
Whether to load all labels into all_labels at __init__. Defaults to True if labels_path is a file or a directory containing a single file.
shuffle: bool
If True, reset_state() will shuffle set_nums_to_process; the method is called by _on_epoch_end within TrainGenerator._train_postiter_processing() and TrainGenerator._val_postiter_processing() (via on_epoch_end()).
superbatch_path: str / None
Path to file or directory from which to load superbatch (preload_superbatch()); see “Control Flow”.
set_nums: list[int] / None
Used to set set_nums_original and set_nums_to_process. If None, will infer from data_path; see “Control Flow”.
superbatch_set_nums: list[int] / None
set_nums to load into superbatch; see “Control Flow”.

How it works:

Data is fed to TrainGenerator() via DataGenerator(). To work, data:

  • must be in one directory (or one file with all data)
  • file extensions must be the same (.npy, .h5, etc.)
  • file names must be enumerated with a common name (data1.npy, data2.npy, …)
  • file batch size (# of samples, or dim 0 slices) should be the same, but can also be an integer or fractional multiple thereof (x2, x3, x1/2, x1/3, …)
  • labels must be in one file - unless feeding input as labels (e.g. autoencoder), which doesn’t require labels files; just pass TrainGenerator(input_as_labels=True)

Flexible batch_size:

A loaded file’s batch size may differ from batch_size, so long as the former is an integer multiple or integer fraction of the latter. Ex:

  • len(loaded) == 32, batch_size == 64 -> will load another file and concatenate into len(batch) == 64.
  • len(loaded) == 64, batch_size == 32 -> will set first half of loaded as batch and cache loaded, then repeat for second half.
  • ‘Technically’, files need not be integer (/ fraction) multiples, as the following load order works with batch_size == 32: len(loaded) == 31, len(loaded) == 1 - but this is not recommended, as it’s bound to fail if using shuffling, or if total number of samples isn’t divisible by batch_size. Other problems may also arise.

Control Flow:

  • set_num: index / Mapping key used to identify and get batch and labels via load_data() and load_labels(). Ex: for DataLoader.numpy_loader(), which expects files shaped (batches, samples, *), it’d do np.load(path)[set_num].
  • set_nums_to_process: will pop set_num from this list until it’s empty; once empty, will set all_data_exhausted=True (update_state()).
  • set_nums_original: will reset set_nums_to_process to this with reset_state(). It’s either set_nums passed to __init__ or is inferred in _set_set_nums() as all available set_nums in data_path file / directory.
  • superbatch: dict of set_num-batch pairs loaded persistently in memory (RAM), as opposed to batch, which is overwritten with each load. Once loaded, batch can be drawn straight from superbatch if set_num is in it (i.e. in superbatch_set_nums).
  • get() returns batch and labels fed through Preprocessor.process().
  • advance_batch() gets the “next” batch and labels. “Next” is determined by set_num, which is popped from the front of set_nums_to_process.
  • batch_exhausted: signals TrainGenerator() that a batch was consumed; this information is set via update_state() per _on_iter_end within TrainGenerator._train_postiter_processing() or TrainGenerator._val_postiter_processing().
  • If using slices (slices_per_batch is not None), then batch_exhausted is set to True only when slice_idx == slices_per_batch - 1.
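
The flow above can also be exercised directly, without a TrainGenerator (a sketch; the paths are hypothetical and the directory layout is assumed to follow “How it works” above):

>>> from deeptrain.data_generator import DataGenerator
>>> dg = DataGenerator('data/train', batch_size=32,
...                    labels_path='data/train/labels.h5')
>>> dg.advance_batch()   # pops a set_num and loads batch & labels
>>> x, y = dg.get()      # batch & labels fed through Preprocessor.process()
>>> dg.update_state()    # sets all_data_exhausted=True once batch_exhausted
...                      # and set_nums_to_process is empty
>>> dg.on_epoch_end()    # increments epoch and calls reset_state()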

__init__:

Instantiation. (“+” == if certain conditions are met)

  • +Infers missing configs based on args
  • Validates args & kwargs, and tries to correct them, printing a “NOTE” or “WARNING” message where appropriate
  • +Preloads all labels into all_labels
  • Instantiates misc internal parameters to predefined values (may be overridden by TrainGenerator loading).
_BUILTINS = {'extensions': {'.csv', '.h5', '.npy'}, 'loaders': {'csv', 'hdf5', 'numpy', 'numpy-lz4f', 'numpy-memmap'}, 'preprocessors': (<class 'deeptrain.util.preprocessors.GenericPreprocessor'>, <class 'deeptrain.util.preprocessors.TimeseriesPreprocessor'>)}
get(skip_validation=False)

Returns (batch, labels) fed to Preprocessor.process().

skip_validation: bool
If False (default), will first call _validate_batch() (which may reset state or advance the batch).
advance_batch(forced=False, is_recursive=False)

Sets next batch and labels; handles dynamic batching.

  • If batch_loaded and not forced (and not is_recursive), prints a warning that batch is loaded, and returns (does nothing)
  • len(batch) != batch_size:
    • < batch_size: calls advance_batch() with is_recursive = True. With each such call, batch and labels are extended (stacked) until matching batch_size.
    • > batch_size, not integer multiple: raises Exception.
    • > batch_size, is integer multiple: makes _group_batch and _group_labels, which are used to set batch and labels.
  • +If set_nums_to_process is empty, will raise Exception; it must have been reset beforehand via e.g. reset_state(). If it’s not empty, sets set_num by popping from set_nums_to_process. (+: only if _group_batch is None)
  • Sets or extends batch via _get_next_batch() (by loading, or from _group_batch or superbatch).
  • +Sets or extends labels via _get_next_labels() (by loading, or from _group_labels, or all_labels). (+: only if labels_path is a path (and not None))
  • Sets set_name, used by TrainGenerator() to print iteration messages.
  • Sets batch_loaded = True, batch_exhausted = False, all_data_exhausted = False, and slice_idx to None if it’s already None (else to 0).
_get_next_batch(set_num=None, warn=True)

Gets batch per set_num.

  • set_num = None: will use self.set_num.
  • warn = False: won’t print warning on superbatch not being preloaded.
  • If _group_batch is not None, will get batch from _group_batch.
  • If set_num is in superbatch_set_nums, will get batch as superbatch[set_num] (if superbatch exists).
  • By default, gets batch via load_data().
_get_next_labels(set_num=None)

Gets labels per set_num.

  • set_num = None: will use self.set_num.
  • If _group_labels is not None, will get labels from _group_labels.
  • If labels were preloaded (all_labels), will get labels as all_labels[set_num].
  • By default, gets labels via load_labels(), if labels_path is set - else, labels=[].
_batch_from_group_batch()

Slices _group_batch per batch_size and _group_batch_idx.

_labels_from_group_labels()

Slices _group_labels per batch_size and _group_batch_idx.

_update_group_batch_state()

Sets “group” attributes to None once a sufficient number of batches has been extracted; else, increments _group_batch_idx.

on_epoch_end()

Increments epoch, calls preprocessor.on_epoch_end(epoch), then reset_state(), and returns epoch.

update_state()

Calls preprocessor.update_state(), and if batch_exhausted and set_nums_to_process == [], sets all_data_exhausted = True to signal TrainGenerator() of epoch end.

reset_state(shuffle=None)

Calls preprocessor.reset_state(), sets batch_exhausted = True, batch_loaded = False, resets set_nums_to_process to set_nums_original, and shuffles set_nums_to_process if shuffle.

_validate_batch()

If all_data_exhausted, calls reset_state(). If batch_exhausted, calls advance_batch().

_make_group_batch_and_labels(n_batches)

Makes _group_batch and _group_labels when the loaded batch’s length is an integer multiple (greater than 1) of batch_size. May shuffle.

  • _group_batch = np.asarray(batch), and _group_labels = np.asarray(labels); each’s len() > batch_size.

  • Shuffles if:
    • shuffle_group_samples: shuffles all samples (dim0 slices)

    • shuffle_group_batches: groups dim0 slices by batch_size, then shuffles the groupings. Ex:

      >>> batch_size == 32
      >>> batch.shape == (128, 100)
      >>> batch = batch.reshape(4, 32, 100)  # .shape == (4, 32, 100)
      >>> shuffle(batch)                     # 24 (= 4!) possible orderings
      >>> batch = batch.reshape(128, 100)    # .shape == (128, 100)
      
  • Sets _group_batch_idx = 0, and calls _update_group_batch_state().

  • Doesn’t affect labels if labels_path is falsy (e.g. None)

batch_exhausted

Is retrieved from and set in preprocessor. Indicates that batch and labels for given set_num were consumed by TrainGenerator (if using slices, that all slices were consumed).

Ex: self.batch_exhausted = 5 will set self.preprocessor.batch_exhausted = 5, and print(self.batch_exhausted) will then print 5 (or something else if preprocessor changes it internally).

batch_loaded

Is retrieved from and set in preprocessor, same as batch_exhausted. Indicates that batch and labels for given set_num are loaded.

slices_per_batch

Is retrieved from and set in preprocessor, same as batch_exhausted.

slice_idx

Is retrieved from and set in preprocessor, same as batch_exhausted.

load_data

Load and return batch data via data_loaders.DataLoader.load_fn(). Used by _get_next_batch() and preload_superbatch().

load_labels

Load and return labels data via data_loaders.DataLoader.load_fn(). Used by _get_next_labels() and preload_all_labels().

_infer_info(path)

Infers unspecified essential attributes from directory and contained files info:

  • Checks that the data directory (path) isn’t empty (files whose names start with '.' aren’t counted)
  • Retrieves data filepaths per path and infers the data extension ext (the most frequent extension in the directory, excluding the “other path” from the count if it’s in the same directory; “other path” is data_path if path == labels_path, and vice versa)
  • Gets base_name as longest common substring among files with ext extension
  • If path is path to a file, then filepaths=[path].
  • If path is None, returns base_name=None, ext=None, filepaths=[].
_infer_and_set_info(data_loader, labels_loader)

Sets data_loader and labels_loader (DataLoader()), using info obtained from _infer_info().

  • If info contains only one filepath, loader will operate with _is_dataset=True.
  • If preload_labels is None and labels_loader._is_dataset, will set preload_labels=True.

data_loader / labels_loader are:

_set_set_nums(set_nums, superbatch_set_nums)

Sets set_nums_original, set_nums_to_process, and superbatch_set_nums.

  • Fetches set_nums via DataLoader._get_set_nums()
  • Sets set_nums_to_process and set_nums_original; if set_nums weren’t passed to __init__, sets to fetched ones.
  • If set_nums were passed, validates that they’re a subset of fetched ones (i.e. can be seen by data_loader).
  • Sets superbatch_set_nums; if passed to __init__ as 'all', sets to the fetched ones. If passed as a list, validates that it’s a subset of the fetched ones.
  • Does not validate set_nums from labels_loader’s perspective; the user is expected to supply labels for each batch sharing a common set_num.
_set_preprocessor(preprocessor, preprocessor_configs)

Sets preprocessor, based on preprocessor passed to __init__:

  • If None, sets to GenericPreprocessor(), instantiated with preprocessor_configs.
  • If an uninstantiated class, will validate that it subclasses Preprocessor(), then instantiate it with preprocessor_configs.
  • If string, will match to a supported builtin.
  • Validates that the set preprocessor subclasses Preprocessor().
preload_superbatch()

Loads all data specified by superbatch_set_nums via load_data(), and assigns them to superbatch for each set_num.

preload_all_labels()

Loads all labels into all_labels using load_labels(), based on set_nums_original.

_init_and_validate_kwargs(kwargs)

Sets and validates kwargs passed to __init__.

  • Ensures data_path is a file or a directory, and labels_path is a file, directory, or None.
  • Ensures kwargs are functional (compares against names in _DEFAULT_DATAGEN_CFG).
  • Sets whichever names were passed with kwargs, and defaults the rest.
_init_class_vars()

Instantiates various internal attributes. Most of these are saved and loaded by TrainGenerator() by default.

deeptrain.introspection module

deeptrain.introspection.compute_gradient_norm(self, input_data, labels, sample_weight=None, learning_phase=0, _id='*', mode='weights', norm_fn=(<ufunc 'sqrt'>, <ufunc 'square'>), scope='local')

Computes gradients w.r.t. layer weights or outputs per _id, and returns norm according to norm_fn and scope.

Arguments:
input_data: np.ndarray / list[np.ndarray] / supported formats
Data w.r.t. which loss is to be computed for the gradient. List of arrays for multi-input networks. “Supported formats” is any valid input to model.
labels: np.ndarray / list[np.ndarray] / supported formats
Labels w.r.t. which loss is to be computed for the gradient.
sample_weight: np.ndarray / list[np.ndarray] / supported formats
kwarg to model.fit(), etc., weighting individual sample losses.
learning_phase: bool / int[bool]
  • 1: use model in train mode
  • 0: use model in inference mode
_id: str / int / list[str/int].
  • int -> idx; str -> name
  • idx: int. Index of layer to fetch, via model.layers[idx].
  • name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
  • list[str/int] -> treat each str element as name, int as idx. Ex: ['gru', 2] gets (e.g.) weights of first layer with name substring ‘gru’, then of layer w/ idx 2.
  • '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with ‘output’ attribute.
mode: str in (‘weights’, ‘outputs’, ‘gradients:weights’, ‘gradients:outputs’)
Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
norm_fn: (function, function) / function
Norm function(s) to apply to gradients arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(x) for x in data)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise.
scope: str in (‘local’, ‘global’)
Whether to apply norm_fn to individual gradient arrays, or to their sum.
Returns:
Gradient norm(s). List of float if scope == 'local' (norms of weights), else float (outer_fn(sum(sum(inner_fn(g)) for g in grads))).

TensorFlow optimizers do gradient clipping according to the clipnorm setting by comparing individual weights’ L2-norms against clipnorm, and rescaling if exceeding. These L2 norms can be obtained using norm_fn=(np.sqrt, np.square) with scope == 'local' and mode='weights'. See:

  • tensorflow.python.keras.optimizer_v2.optimizer_v2._clip_gradients
  • keras.optimizers.clip_norm
  • tensorflow.python.ops.clip_ops.clip_by_norm
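
A small NumPy illustration of the norm_fn composition described above (plain arrays standing in for gradients; not DeepTrain's internal code path):

>>> import numpy as np
>>> grads = [np.array([3., 4.]), np.array([1., 2., 2.])]
>>> # norm_fn=(np.sqrt, np.square), scope='local': one L2 norm per array
>>> [float(np.sqrt(np.sum(np.square(g)))) for g in grads]     # -> [5.0, 3.0]
>>> # scope='global': outer_fn applied to the sum over all arrays
>>> float(np.sqrt(sum(np.sum(np.square(g)) for g in grads)))  # -> sqrt(34) ~= 5.83
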
deeptrain.introspection.gradient_norm_over_dataset(self, val=False, learning_phase=0, mode='weights', norm_fn=(<ufunc 'sqrt'>, <ufunc 'square'>), stat_fn=<function median>, n_iters=None, prog_freq=10, w=1, h=1)

Aggregates gradient norms over the dataset, one iteration at a time. Useful for estimating what clipnorm value to use for gradient clipping. Plots a histogram of the gathered data when finished. Also see compute_gradient_norm().

Arguments:
val: bool
  • True: gather over val_datagen batches
  • False: gather over datagen batches
learning_phase: bool / int[bool]
  • True: get gradients of model in train mode
  • False: get gradients of model in inference mode
mode: str in (‘weights’, ‘outputs’)
Whether to get gradients with respect to layer weights or outputs.
norm_fn: (function, function) / function
Norm function(s) to apply to gradients arrays when gathering. (np.sqrt, np.square) for L2-norm, np.abs for L1-norm. Computed as: outer_fn(sum(inner_fn(g) for g in grads)), where outer_fn, inner_fn = norm_fn if norm_fn is list/tuple, and inner_fn = norm_fn and outer_fn = lambda x: x otherwise.
stat_fn: function
Aggregate function to apply on computed norms. If np.mean, will gather mean of gradients; if np.median, the median, etc. Computed as: stat_fn(outer_fn(sum(inner_fn(g) for g in grads))).
n_iters: int / None
Number of expected iterations over entire dataset. Can be used to iterate over subset of entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
prog_freq: int
How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but are same if not using slices). E.g. 5: ....|5....|10....|15.
w, h: float
Scale figure width & height, respectively.
Returns:
grad_norms: np.ndarray
Norms of gradients for every iteration. Shape: (iters_processed, n_params), where n_params is number of gradient arrays whose norm stats were computed at each iteration.
batches_processed: int
Number of batches processed.
iters_processed: int
Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
deeptrain.introspection.gradient_sum_over_dataset(self, val=False, learning_phase=0, mode='weights', n_iters=None, prog_freq=10, plot_kw={})

Computes cumulative sum of gradients over dataset, one iteration at a time, preserving full array shapes. Useful for computing mean of gradients over dataset, or other aggregate metrics.

Arguments:
val: bool
  • True: gather over val_datagen batches
  • False: gather over datagen batches
learning_phase: bool / int[bool]
  • True: get gradients of model in train mode
  • False: get gradients of model in inference mode
mode: str in (‘weights’, ‘outputs’)
Whether to get gradients with respect to layer weights or outputs.
n_iters: int / None
Number of expected iterations over entire dataset. Can be used to iterate over subset of entire dataset. If None, will return upon DataGenerator.all_data_exhausted.
prog_freq: int
How often to print f'|{batch_idx}', and '.' otherwise, in terms of number of batches (not iterations, but are same if not using slices). E.g. 5: ....|5....|10....|15.
plot_kw: dict
Kwargs to pass to see_rnn.features_hist; defaults to {'share_xy': False, 'center_zero': True}.
Returns:
grad_sum: dict[str: np.ndarray]
Gradient arrays summed over dataset. Structure: {name: array, name: array, ...}, where name is name of weight array or layer output.
batches_processed: int
Number of batches processed.
iters_processed: int
Number of iterations processed (if using e.g. 4 slices per batch, will equal 4 * batches_processed).
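
For instance, the mean gradient per weight over the dataset follows directly from the returned values (a sketch; tg stands for a TrainGenerator instance, on which the method is assumed bound given its self argument):

>>> grad_sum, batches, iters = tg.gradient_sum_over_dataset(val=True)
>>> grad_mean = {name: arr / iters for name, arr in grad_sum.items()}
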
deeptrain.introspection._gather_over_dataset(self, gather_fn, val=False, n_iters=None, prog_freq=10)

Iterates over DataGenerator, applying gather_fn to every batch (or slice). Stops after n_iters, or when DataGenerator.all_data_exhausted if n_iters is None. Useful for monitoring quantities over the course of training or inference.

gather_fn recursively updates data; as such, it can be used to append to a list, update a dictionary, operate on an array, etc. Review source code for exact logic.

deeptrain.introspection._make_gradients_fn(model, learning_phase, mode, return_names=False)

Makes a reusable gradient-getter function, separately for TF Eager & Graph execution. The Eager variant is only pseudo-reusable; gradient tensors are still fetched anew on each call - the Graph variant should be significantly faster.

deeptrain.introspection.print_dead_weights(model, dead_threshold=1e-07, notify_above_frac=0.001, notify_detected_only=False)

Print names of dead weights and their proportions. Useful for debugging vanishing and exploding gradients, or quantifying sparsity.

Arguments:
model: models.Model / models.Sequential (keras / tf.keras)
The model.
dead_threshold: float
Threshold below which to count the weight as “dead”, in absolute value.
notify_above_frac: float
Print only if fraction of weights counted “dead” exceeds this (e.g. if there are 11 absolute values < dead_threshold out of 1000).
notify_detected_only: bool
  • True: print text only if dead weights are discovered
  • False: print a “not found given thresholds” message when appropriate
deeptrain.introspection.print_nan_weights(model, notify_detected_only=False)

Print names of NaN/Inf weights and their proportions. Useful for debugging exploding or buggy gradients.

Arguments:
model: models.Model / models.Sequential (keras / tf.keras)
The model.
notify_detected_only: bool
  • True: print text only if NaN/Inf weights are discovered
  • False: print a “none found” message if no NaN/Inf weights were found
deeptrain.introspection.print_large_weights(model, large_threshold=3, notify_above_frac=0.001, notify_detected_only=False)

Print names of weights in excess of set absolute value, and their proportions; excludes Inf. Useful for debugging exploding or buggy gradients.

Arguments:
model: models.Model / models.Sequential (keras / tf.keras)
The model.
large_threshold: float
Threshold above which to count the weight’s absolute value as “large”.
notify_above_frac: float
Print only if the fraction of weights counted “large” exceeds this (e.g. if there are 11 absolute values > large_threshold out of 1000).
notify_detected_only: bool
  • True: print text only if large weights are discovered
  • False: print a “none found” message if no large weights were found
deeptrain.introspection._sample_weight_built(model)

In Graph execution, model._feed_sample_weights isn’t built unless model is compiled with sample_weight_mode set, or train_on_batch or test_on_batch is called w/ sample_weight passed.

deeptrain.introspection.interrupt_status(self) -> (<class 'bool'>, <class 'bool'>)

Prints whether TrainGenerator was interrupted (e.g. KeyboardInterrupt, or via exception) during train() and validate(). Returns bools (True for interrupted, else False) for each, as (train, val).

Not foolproof; user can set flags manually or via callbacks. For further assurance, check temp_history, val_temp_history, and cache attributes (e.g. _preds_cache), which are cleared at the end of validate() by default; this method checks only the flags: _train_loop_done, _train_postiter_processed, _val_loop_done, _val_postiter_processed.
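
A usage sketch (tg stands for a TrainGenerator instance, on which the method is assumed bound given its self argument):

>>> train_interrupted, val_interrupted = tg.interrupt_status()
>>> if val_interrupted:
...     tg.reset_validation()  # restart validation afresh; see TrainGenerator.reset_validation()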

deeptrain.introspection.info(self)

Prints various useful TrainGenerator & DataGenerator attributes, and interrupt status.

deeptrain.metrics module

deeptrain.metrics.f1_score(y_true, y_pred, pred_threshold=0.5, beta=1)
deeptrain.metrics.mean_squared_error(y_true, y_pred, sample_weight=1)
deeptrain.metrics.mean_absolute_error(y_true, y_pred, sample_weight=1)
deeptrain.metrics.roc_auc_score(y_true, y_pred)

deeptrain.preprocessing module

deeptrain.preprocessing.data_to_hdf5(savepath, batch_size, loaddir=None, data=None, shuffle=False, compression='lzf', dtype=None, load_fn=None, oversample_remainder=True, batches_dim0=False, overwrite=None, verbose=1)

Convert data to hdf5-group (.h5) format, in batch_size sample sets.

Arguments:
savepath: str
Absolute path to where to save file.
batch_size: int
Number of samples (dim0 slices) to save per file.
loaddir: str
Absolute path to directory from which to load data.
data: np.ndarray / list[np.ndarray]
Shape: (samples, *), or (batches, samples, *) (the latter requires batches_dim0=True). With the former, if len(data) == 320 and batch_size == 32, will make a 10-set .h5 file.
shuffle: bool
Whether to shuffle samples (dim0 slices).
compression: str
Compression type to use. kwarg to h5py.File().create_dataset().
dtype: str / np.dtype
Savefile dtype; kwarg to .create_dataset(). Defaults to data’s dtype.
load_fn: function / callable
Used on supported paths (.npy) in loaddir to load data.
oversample_remainder: bool. Relevant only when passing data.
  • True -> randomly draw (batch_size - remainder) samples to fill the incomplete batch.
  • False -> drop remainder samples.
batches_dim0: bool
Assume shapes - True: (batches, samples, *); False: (samples, *).
overwrite: bool / None

If savepath file exists,

  • True -> replace it
  • False -> don’t replace it
  • None -> ask confirmation via user input
verbose: bool
Whether to print preprocessing progress.

Notes:

  • If supplying loaddir instead of data, will iteratively load files of a supported format (.npy). len() of each loaded file must be batch_size divided by a positive integer (an integer-fraction multiple, <= batch_size). So batch_size == 32 and len() == 16 works, but len() == 48 or len() == 24 doesn’t.
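
A usage sketch (the array shape and save path are hypothetical):

>>> import numpy as np
>>> from deeptrain.preprocessing import data_to_hdf5
>>> data = np.random.randn(320, 64)          # (samples, *)
>>> data_to_hdf5('data.h5', batch_size=32, data=data,
...              overwrite=True, verbose=1)  # -> 10 sets of 32 samples each
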
deeptrain.preprocessing.numpy_to_lz4f(data, savepath=None, level=9, overwrite=None)

Do lz4-framed compression on data. (Install compressor via !pip install py-lz4framed)

Arguments:
data: np.ndarray
Data to compress.
savepath: str
Path to where to save file.
level: int
1 to 9; higher = greater compression
overwrite: bool

If savepath file exists,

  • True -> replace it
  • False -> don’t replace it
  • None -> ask confirmation via user input
Returns:
np.ndarray - compressed array.

Example:

>>> import numpy as np
>>> import lz4framed as lz4f  # from py-lz4framed
>>> numpy_to_lz4f(savedata, savepath=path)
...
>>> # load & decompress
>>> bytes_npy = lz4f.decompress(np.load(path))
>>> loaddata = np.frombuffer(bytes_npy,
...                          dtype=savedata.dtype,  # must be original's
...                          ).reshape(*savedata.shape)
deeptrain.preprocessing.numpy_data_to_numpy_sets(data, labels, savedir=None, batch_size=32, shuffle=True, data_basename='batch', oversample_remainder=True, overwrite=None, verbose=1)

Save data in batch_size chunks, possibly shuffling samples.

Arguments:
data: np.ndarray
Data to batch along labels & save.
labels: np.ndarray
Labels to batch along data and save.
savedir: str / None
Directory in which to save processed data. If None, won’t save.
batch_size: int
Number of samples (dim0 slices) to form a ‘set’ with.
data_basename: str
Files are saved with this name prepending the set number - e.g. ‘batch__1.npy’, ‘batch__2.npy’, …
oversample_remainder: bool
  • True -> randomly draw (batch_size - remainder) samples to fill the incomplete batch.
  • False -> drop remainder samples.
overwrite: bool / None

If savepath file exists,

  • True -> replace it
  • False -> don’t replace it
  • None -> ask confirmation via user input
verbose: bool
Whether to print preprocessing progress.
Returns:
data, labels: processed data & labels.
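
A usage sketch (shapes and savedir are hypothetical):

>>> import numpy as np
>>> from deeptrain.preprocessing import numpy_data_to_numpy_sets
>>> data   = np.random.randn(161, 8)          # 161 = 5 * 32 + 1 -> remainder of 1
>>> labels = np.random.randint(0, 2, (161,))
>>> data, labels = numpy_data_to_numpy_sets(
...     data, labels, savedir='data/train', batch_size=32,
...     shuffle=True, oversample_remainder=True)
>>> # files are saved as 'batch__1.npy', 'batch__2.npy', ... in savedir
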
deeptrain.preprocessing.numpy2D_to_csv(data, savepath=None, batch_size=None, columns=None, sample_dim=1, overwrite=None)

Save 2D data as .csv.

Arguments:
data: np.ndarray
Data to save, shaped (batches, samples).
savepath: str
Path to where to save the file.
batch_size: int
Number of rows per column; can differ from data’s.
columns: list of str
Column names for the data frame; defaults to enumerated columns (0, 1, 2, …).
sample_dim: int
Dimension of data along which samples reside (the dimension batch_size applies to).
overwrite: bool

If savepath file exists,

  • True -> replace it
  • False -> don’t replace it
  • None -> ask confirmation via user input

Example:

Suppose we have labels, 16 per batch, and 8 batches. Each batch is shaped (16,) - stacked, (8, 16). Dim 1 is thus sample_dim. Also see examples/preprocessing/timeseries.py.

>>> data.shape == (8, 16)  # (num_batches, samples)
>>> numpy2D_to_csv(data, "data.csv", batch_size=32, sample_dim=1)
>>> # if it was (samples, num_batches), sample_dim would be 0.
... # This will make a DataFrame of 4 columns, 32 rows per column,
... # reshaping `data` to correctly concatenate samples from dim0 to dim1.

deeptrain.train_generator module

class deeptrain.train_generator.TrainGenerator(model, datagen, val_datagen, epochs=1, logs_dir=None, best_models_dir=None, loadpath=None, callbacks=None, fit_fn='train_on_batch', eval_fn='evaluate', key_metric='loss', key_metric_fn=None, val_metrics=None, custom_metrics=None, input_as_labels=False, max_is_best=None, val_freq={'epoch': 1}, plot_history_freq={'epoch': 1}, unique_checkpoint_freq={'epoch': 1}, temp_checkpoint_freq=None, class_weights=None, val_class_weights=None, reset_statefuls=False, iter_verbosity=1, logdir=None, optimizer_save_configs=None, optimizer_load_configs=None, plot_configs=None, model_configs=None, **kwargs)

Bases: deeptrain.util._traingen_utils.TraingenUtils

The central DeepTrain class. Interfaces training, validation, checkpointing, data loading, and progress tracking.

Arguments:
model: models.Model / models.Sequential [keras / tf.keras]
Compiled model to train.
datagen: DataGenerator
Train data generator; fetches inputs and labels, handles preprocessing, shuffling, stateful formats, and informing TrainGenerator when a dataset is exhausted (epoch end).
val_datagen: DataGenerator
Validation data generator.
epochs: int
Number of train epochs.
logs_dir: str / None
Path to directory in which to generate log directories, which include TrainGenerator state, state report, model data, and others; see checkpoint(). If None, will not checkpoint - but model saving is still possible via best_models_dir.
best_models_dir: str / None
Path to directory where to save best model. “Best” means having new highest (max_is_best==True) or lowest (max_is_best==False) entry in key_metric_history. See _save_best_model().
loadpath: str / None
Path to .h5 file containing TrainGenerator state to load (postfixed '__state.h5' by default). See load().
callbacks: dict[str: function] / TraingenCallback / None
Functions to apply at various stages, including training, validation, saving, loading, and __init__. See TraingenCallback.
fit_fn: str / function(x, y, sample_weight)
Function, or name of a model method, used to feed data during training; if str, will define, e.g., fit_fn = getattr(model, 'fit'). If a function, its name must include 'fit' or 'train' as a substring (currently both behave identically).
eval_fn: str / function(x, y, sample_weight)

Function, or name of a model method, used to feed data during validation; if str, will define, e.g., eval_fn = getattr(model, 'evaluate'). If a function, its name must include 'evaluate' or 'predict' as a substring:

  • 'evaluate': eval_fn uses data & labels to return metrics.
  • 'predict': eval_fn uses data to return predictions, which are used internally to compute metrics.
key_metric: str
Name of metric to track for saving best model; will store in key_metric_history. See _save_best_model().
key_metric_fn: function / None
Custom function to compute key metric; overrides key_metric if not None.
val_metrics: list[str] / None

Names of metrics to track during validation.

  • If 'predict' is not in eval_fn.__name__, is overridden by model metrics (model.compile(metrics=...))
  • If 'loss' is not included, will prepend.
  • If '*' is included, will insert model metrics at its position and pop '*'. Ex: ['*', 'f1_score'] -> ['loss', 'accuracy', 'f1_score'].
custom_metrics: dict[str: function]

Name-function pairs of custom functions to use for gathering metrics. Functions must obey (y_true, y_pred) input signature for first two arguments. They may additionally supply sample_weight and pred_threshold, which will be detected and used automatically.

  • Note: if using a custom metric in model.compile(loss=tf_fn), name in custom_metrics must be function’s code name, i.e. {tf_fn.__name__: fn} (where fn is e.g. numpy version).
input_as_labels: bool
Feed model input also to its output. Ex: autoencoders.
max_is_best: bool
Whether to consider greater key_metric as better in saving best model. See _save_best_model(). If None, defaults to False if key_metric=='loss', else True.
val_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
How frequently to validate. {'epoch': 1} -> every epoch; {'iter': 24} -> every 24 train iterations. Only one key-value pair supported. If None, won’t validate.
plot_history_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
How frequently to plot train & validation history. Only one key-value pair supported. If None, won’t plot history.
unique_checkpoint_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
How frequently to make checkpoints with unique savefile names, as opposed to temporary ones which are overwritten each time. Only one key-value pair supported. If None, won’t make unique checkpoints.
temp_checkpoint_freq: None / dict[str: int], str in {‘iter’, ‘batch’, ‘epoch’, ‘val’}
How frequently to make checkpoints with the same predefined name, to be overwritten (“temporary”); serves as an intermediate checkpoint to unique ones, if needed. Only one key-value pair supported. If None, won’t make temporary checkpoints.
class_weights: dict[int: int] / None

Integer-mapping of class labels to their “weights”; if not None, will feed sample_weight mediated by the weights to train function (fit_fn).

>>> class_weights = {0: 4, 1: 1}
>>> labels        == [1, 1, 0, 1]  # if
>>> sample_weight == [4, 4, 1, 4]  # then
val_class_weights: dict[int: int] / None
class_weights for validation function (eval_fn).
reset_statefuls: bool
Whether to call model.reset_states() at the end of every batch (train and val).
iter_verbosity: int in {0, 1, 2}
  • 0: print no iteration info
  • 1: print name of set being fit / validated, metric names and values,
    and model.reset_states() being called
  • 2: print a '.' at every iteration (useful if having multiple iterations per batch)
logdir: str / None
Directory where to write logs to (see logs_dir). Use to specify an existing directory (to, for example, resume training and logging in original folder). Overrides logs_dir.
optimizer_save_configs: dict / None
Dict specifying which optimizer attributes to include or exclude when saving. See save().
optimizer_load_configs: dict / None
Dict specifying which optimizer attributes to include or exclude when loading. See load().
plot_configs: dict / None
Dict specifying get_history_fig() behavior. See _DEFAULT_PLOT_CFG, and _make_plot_configs_from_metrics().
model_configs: dict / None

Dict specifying model information. Intended usage is: create model according to the dict, specifying hyperparameters, loss function, etc.

>>> def make_model(batch_shape, units, optimizer, loss):
...     ipt = Input(batch_shape=batch_shape)
...     out = Dense(units)(ipt)
...     model = Model(ipt, out)
...     model.compile(optimizer, loss)
...     return model
...
>>> model_configs = {'batch_shape': (32, 16), 'units': 8,
...                  'optimizer': 'adam', 'loss': 'mse'}
>>> model = make_model(**model_configs)

Checkpoints will include an image report with the entire dict; the larger the portion of the model that’s created according to model_configs, the more will be documented for easy reference.

kwargs: keyword arguments.
See _DEFAULT_TRAINGEN_CFG. kwargs and all other arguments are subject to validation and correction by _validate_traingen_configs().

__init__:

Instantiation. (“+” == if certain conditions are met)

  • Pulls methods from TraingenUtils
  • Validates args & kwargs, and tries to correct them, printing a “NOTE” or “WARNING” message where appropriate
  • +Instantiates logging directory
  • +Loads TrainGenerator, datagen, and val_datagen states
  • +Loads model and optimizer weights (but not model architecture)
  • +Preloads train & validation data (before a call to train() is made).
  • +Applies initial callbacks
  • +Logs initial state (_log_init_state())
  • Captures and saves all arguments passed to __init__
  • Instantiates misc internal parameters to predefined values (may be overridden by loading).
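
A minimal end-to-end sketch (paths are hypothetical, and model is assumed to be an already-compiled keras / tf.keras model):

>>> from deeptrain.train_generator import TrainGenerator
>>> from deeptrain.data_generator import DataGenerator
>>> datagen     = DataGenerator('data/train', batch_size=32,
...                             labels_path='data/train/labels.h5')
>>> val_datagen = DataGenerator('data/val', batch_size=32,
...                             labels_path='data/val/labels.h5')
>>> tg = TrainGenerator(model, datagen, val_datagen, epochs=4,
...                     logs_dir='logs', best_models_dir='models')
>>> tg.train()
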
train()

The train loop.

  • Fetches data from get_data
  • Fits data via fit_fn
  • Processes fit metrics in _train_postiter_processing
  • Stores metrics in history
  • Applies 'train:iter', 'train:batch', and 'train:epoch' callbacks
  • Calls validate when appropriate

Interruption:

  • Safe: during get_data, which can be called indefinitely without changing any attributes.
  • Avoid: during _train_postiter_processing, after fit_fn has been applied and weights updated - metrics aren’t yet stored and _train_postiter_processed remains False, so the loop restarts without recording progress.
  • Best bet is during validate(), as get_data may be too brief.
validate(record_progress=True, clear_cache=True, restart=False, use_callbacks=True)

Validation loop.

  • Fetches data from get_data
  • Applies function based on _eval_fn_name
  • Processes and caches metrics/predictions in _val_postiter_processing
  • Applies 'val:iter', 'val:batch', and 'val:epoch' callbacks
  • Calls _on_val_end at end of validation to compute metrics and store them in val_history
  • Applies 'val_end' and, when applicable, ('val_end', 'train:epoch') callbacks
  • If restart, calls reset_validation().
Arguments:
record_progress: bool
If False, won’t update val_history, _val_iters, _batches_validated.
clear_cache: bool
If False, won’t call clear_cache(); useful for keeping preds & labels acquired during validation.
restart: bool
If True, will call reset_validation() before the validation loop to reset validation attributes; useful for starting afresh (e.g. if interrupted).
use_callbacks: bool
If False, won’t call apply_callbacks() or plot_history().

Interruption:

  • Safe: during get_data, which can be called indefinitely without changing any attributes.
  • Avoid: during _val_postiter_processing. Model remains unaffected*, but caches are updated; a restart may yield duplicate appending, which will error or yield inaccuracies. (* forward pass may consume random seed if random ops are used)
  • In practice: prefer interrupting immediately after _print_iter_progress executes.
_train_postiter_processing(metrics)

Procedures done after every train iteration. Similar to _val_postiter_processing(), except operating on train rather than val variables, and calling validate() when appropriate.

_val_postiter_processing(record_progress=True, use_callbacks=True, metrics=None, batch_size=None)

Procedures done after every validation iteration. Unless marked “always”, are conditional and may skip.

  • Update temp val history (always)
  • Update val_datagen state (always)
  • Update val cache (preds, labels, etc)
  • Update val history
  • Reset statefuls
  • Print progress
  • Apply callbacks

Executes internal “callbacks” when appropriate: _on_iter_end, _on_batch_end, _on_epoch_end. List not exhaustive.

_on_val_end(record_progress, use_callbacks, clear_cache)

Procedures done after validate(). Unless marked “always”, are conditional and may skip. List not exhaustive.

  • Update train/val history
  • Clear cache
  • Plot history
  • Checkpoint
  • Apply callbacks
  • Check model health
  • Validate batch_size (always)
  • Reset validation flags: _inferred_batch_size, _val_loop_done, _train_loop_done (always)
get_data(val=False)

Get train (val=False) or validation (val=True) data from datagen or val_datagen, respectively. See DataGenerator.

DataGenerator.get() returns x, labels; if input_as_labels == True, sets y = x - else, y = labels. Either way, sets class_labels = labels. Generates sample_weight from class_labels.

clear_cache(reset_val_flags=False)

Call to reset cache attributes accumulated during validation; useful for “restarting” validation (before calling validate()).

Attributes set to []: {'_preds_cache', '_labels_cache', '_sw_cache', '_class_labels_cache', '_set_name_cache', '_val_set_name_cache', '_y_true', '_val_sw'}.

reset_validation()

Used to restart validation (e.g. in case interrupted); calls clear_cache() and DataGenerator.reset_state() (and, if reset_statefuls, model.reset_states()).

Does not reset validation counters (e.g. _val_iters).

_should_do(freq_config, forced=False)

Checks whether a counter meets a frequency as specified in val_freq, plot_history_freq, unique_checkpoint_freq, temp_checkpoint_freq.

“Counter” is one of _fit_iters, _batches_fit, epoch, and _times_validated. Ex: with unique_checkpoint_freq = {'batch': 5}, checkpoint() will make a unique checkpoint on every 5th batch fitted during train().

_update_val_iter_cache()

Called by _on_iter_end within _val_postiter_processing(); updates validation cache variables (_labels_cache, _preds_cache, _class_labels_cache, _sw_cache).

If val_datagen has a non-None slice_idx, will preserve batch-slice structure:

>>> [[y00, y01, y02], [y10, y11, y12]]    # 2 batches, 3 slices/batch
>>> [[y00, y01], [y10, y11], [y20, y21]]  # 3 batches, 2 slices/batch
_print_train_progress()

Called within _train_postiter_processing(), by on_batch_end().

_print_val_progress()

Called within _val_postiter_processing(), by on_batch_end().

_print_progress(metrics, endchar='\n')

Called by _print_train_progress() and _print_val_progress().

_print_iter_progress(val=False)

Called within train() and validate().

plot_history(update_fig=True, w=1, h=1)

Plots train & validation history (from history and val_history).

  • update_fig=True -> store latest fig in _history_fig.
  • w & h scale the width & height, respectively, of the figure.
  • Plots configured by plot_configs.
_apply_callbacks(stage)

Callbacks. See examples/callbacks

Two approaches:
  1. Class-based: inherit deeptrain.callbacks.TraingenCallback, define stage-based methods, e.g. on_train_epoch_end. Methods also take stage argument for further control, e.g. to only call on_train_epoch_end when stage == ('val_end', 'train:epoch').

  2. Function-based: make a dict of stage-function call pairs, e.g.:

    >>> {'train:epoch': (fn1, fn2),
    ... 'val:batch': fn3,
    ... ('val_end', 'train:epoch'): fn4}
    

    A callback will execute if its key is contained in the stage passed to _apply_callbacks; e.g. (fn1, fn2) (keyed 'train:epoch') will execute on stage == ('val_end', 'train:epoch'), but fn4 (keyed ('val_end', 'train:epoch')) won’t execute on stage == 'train:epoch'.

_init_callbacks()

Instantiates callback objects (must subclass TraingenCallback), passing in TrainGenerator instance as first (and only) argument. Enables custom callbacks utilizing TrainGenerator attributes and methods.

check_health(dead_threshold=1e-07, dead_notify_above_frac=0.001, large_threshold=3, large_notify_above_frac=0.001, notify_detected_only=True)

Checks whether any layer weights are near-zero (‘dead’), NaN, or excessively large; very fast / inexpensive.

Arguments:
dead_threshold: float
Count values below this as zeros.
dead_notify_above_frac: float
If fraction of values exceeds this, print it and the weight’s name.
notify_detected_only: bool
  • True -> print only if dead/NaN found
  • False -> print a ‘not found’ message
destroy(confirm=False, verbose=1)

Class ‘destructor’. Sets own, datagen’s, and val_datagen’s attributes to [] (which can free memory of arrays), then deletes them. Also deletes ‘model’ attribute, but this has no effect on memory allocation until it’s dereferenced globally and the TensorFlow/Keras graph is cleared (best bet is to restart the Python kernel).

fit_fn
eval_fn
epoch
val_epoch
_prepare_initial_data(from_load=False)

Preloads first batch for training and validation, and superbatch if available.

_init_logger()

Instantiate log directory for checkpointing. If logdir was provided at __init__, will use it - else, will make a directory and assign its absolute path to logdir.

_init_and_validate_kwargs(kwargs)

Sets and validates **kwargs, raising exception if kwargs result in an invalid configuration, or correcting them (and possibly notifying) when possible. Also catches unused arguments.

_init_class_vars()

Instantiates various internal attributes. Most of these are saved and loaded by default.

deeptrain.visuals module

deeptrain.visuals.binary_preds_per_iteration(_labels_cache, _preds_cache, w=1, h=1)

Plots binary preds vs. labels in a heatmap, separated by batches, grouped by slices.

To be used with sigmoid outputs (1 unit). Both inputs are to be shaped (batches, slices, samples, *).

Arguments:
_labels_cache: list[np.ndarray]
List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., first array should correspond to labels of same batch as predictions in _preds_cache).
_preds_cache: list[np.ndarray]
List of predictions cached during training/validation; see docs on _labels_cache.
w, h: float
Scale figure width & height, respectively.
deeptrain.visuals.binary_preds_distribution(_labels_cache, _preds_cache, pred_th, w=1, h=1)

Plots binary preds in a scatter plot, labeling dots according to their labels, and showing pred_th as a vertical line. Positive class (1) is labeled red, negative (0) blue; a red dot far left (close to 0) is hence a strongly misclassified positive class, and vice versa.

To be used with sigmoid outputs (1 unit).

Arguments:
_labels_cache: list[np.ndarray]
List of labels cached during training/validation; insertion order must match that of _preds_cache (i.e., first array should correspond to labels of same batch as predictions in _preds_cache).
_preds_cache: list[np.ndarray]
List of predictions cached during training/validation; see docs on _labels_cache.
pred_th: float
Predict threshold (e.g. 0.5), plotted as a vertical line.
w, h: float
Scale figure width & height, respectively.
deeptrain.visuals.infer_train_hist(model, input_data, layer=None, keep_borders=True, bins=100, xlims=None, fontsize=14, vline=None, w=1, h=1)

Histograms of flattened layer output values in inference and train mode. Useful for comparing the effect of layers that behave differently in train vs. inference mode (Dropout, BatchNormalization, etc.) on model predictions (or intermediate activations).

Arguments:
model: models.Model / models.Sequential (keras / tf.keras)
The model.
input_data: np.ndarray / list[np.ndarray]
Data to feed to model to fetch outputs. List of arrays for multi-input networks.
layer: layers.Layer / None
Layer whose outputs to fetch; defaults to last layer (output).
keep_borders: bool
Whether to keep the plots’ bounding box.
bins: int
Number of histogram bins; kwarg to plt.hist()
xlims: tuple[float, float] / None
Histogram x limits. Defaults to min/max of flattened data per plot.
fontsize: int
Title font size.
vline: float / None
x-coordinate of vertical line to draw (e.g. predict threshold).
w, h: float
Scale figure width & height, respectively.
deeptrain.visuals.layer_hists(model, _id='*', mode='weights', input_data=None, labels=None, omit_names='bias', share_xy=(0, 0), configs=None, **kw)

Histogram grid of layer weights, outputs, or gradients.

Arguments:
model: models.Model / models.Sequential (keras / tf.keras)
The model.
_id: str / int / list[str/int].
  • int -> idx; str -> name
  • idx: int. Index of layer to fetch, via model.layers[idx].
  • name: str. Name of layer (full or substring) to be fetched. Returns earliest match if multiple found.
  • list[str/int] -> treat each str element as name, int as idx. Ex: ['gru', 2] gets (e.g.) weights of first layer with name substring ‘gru’, then of layer w/ idx 2.
  • '*' (wildcard) -> get (e.g.) outputs of all layers (except input) with ‘output’ attribute.
mode: str in (‘weights’, ‘outputs’, ‘gradients:weights’, ‘gradients:outputs’)
Whether to fetch layer weights, outputs, or gradients (w.r.t. outputs or weights).
input_data: np.ndarray / list[np.ndarray] / None
Data to feed to model to fetch outputs / gradients. List of arrays for multi-input networks. Ignored for mode='weights'.
labels: np.ndarray / list[np.ndarray]
Labels to feed to model to fetch gradients. List of arrays for multi-output networks. Ignored for mode in ('weights', 'outputs').
omit_names: str / list[str] / tuple[str]
Names of weights to omit for _id specifying layer names. E.g. for Dense, omit_names='bias' will fetch only kernel weights. Ignored for mode != 'weights'.
share_xy: tuple[bool, bool] / tuple[str, str]
Whether to share x or y limits in histogram grid, respectively. kwarg to plt.subplots(); can be 'col' or 'row' for sharing along rows or columns, respectively.
configs: dict / None

kwargs to customize various plot schemes:

  • 'plot': passed partly to ax.hist() in see_rnn.hist_clipped(); include peaks_to_clip to adjust ylims with a number of peaks disregarded. See help(see_rnn.hist_clipped). ax = subplots axis
  • 'subplot': passed to plt.subplots()
  • 'title': passed to fig.suptitle(); fig = subplots figure
  • 'tight': passed to fig.subplots_adjust()
  • 'annot': passed to ax.annotate()
  • 'save': passed to fig.savefig() if savepath is not None.
kw: dict / kwargs
kwargs passed to see_rnn.features_hist.
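
Usage sketches (x, y, and the layer name 'dense' are hypothetical):

>>> from deeptrain.visuals import layer_hists
>>> layer_hists(model, _id='*', mode='weights')        # all layer weights
>>> layer_hists(model, _id='dense', mode='gradients:weights',
...             input_data=x, labels=y)                # gradients need data & labels
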
deeptrain.visuals.viz_roc_auc(y_true, y_pred)

Plots the Receiver Operator Characteristic curve.

deeptrain.visuals.get_history_fig(self, plot_configs=None, w=1, h=1)

Plots train / validation history according to plot_configs.

Arguments:
plot_configs: dict / None
See _DEFAULT_PLOT_CFG. If None, defaults to TrainGenerator.plot_configs (which itself defaults to _PLOT_CFG in configs.py).
w, h: float
Scale figure width & height, respectively.

plot_configs is structured as follows:

>>> {'fig_kw': fig_kw,
...  '0': {reserved_name: value,
...        plt_kw: value},
...  '1': {reserved_name: value,
...        plt_kw: value},
...  ...}
  • fig_kw: dict, passed to plt.subplots(**fig_kw)
  • reserved_name: str, one of ('metrics', 'x_ticks', 'vhlines', 'mark_best_cfg', 'ylims', 'legend_kw'). Used to configure supported custom plot behavior (see “Builtin plot customs” below).
  • plt_kw: str, name of kwarg to pass directly to plt.plot().
  • value: depends on key; see default plot_configs in _DEFAULT_PLOT_CFG and misc._make_plot_configs_from_metrics().

Only 'metrics' and 'x_ticks' keys are required for each dict - others have default values.

Builtin plot customs: (reserved_name)

  • 'metrics' (required): names of metrics to plot from histories, as {'train': train_metrics, 'val': val_metrics} (at least one metric name required, for only one of train/val - need to have “something” to plot).
  • 'x_ticks' (required): x-coordinates of the respective metrics, of the same len().
  • 'vhlines': dict[‘v’ / ‘h’: float]. vertical/horizontal lines; e.g. {'v': 10} will draw a vertical line at x = 10, and {'h': .5} at y = .5.
  • 'mark_best_cfg': {'train': metric_name} or {'val': metric_name}, and (optional) {'max_is_best': bool} pairs. Will mark the plot to indicate a metric optimum (max if 'max_is_best' is True, the default; else min).
  • 'ylims': y-limits of plot panes.
  • 'legend_kw': passed to plt.legend(); if None, no legend is drawn.

Defaults handling:

Keys and subkeys, where absent, will be filled from configs returned by misc._make_plot_configs_from_metrics().

  • If plot pane '0' is lacking entirely, it’ll be copied from the defaults.
  • If subkey 'color' in dict with key '0' is missing, will fill from defaults['0']['color'].

Further info:

  • Every key’s iterable value (list, etc) must be of same len as number of metrics in 'metrics'; this is ensured within cfg_fn.
  • Metrics are plotted in order of insertion (at both dict and list level), so later metrics will carry over to additional plot panes if number of metrics exceeds plot_first_pane_max_vals; see cfg_fn.
  • A convenient option is to change _PLOT_CFG in configs.py and pass plot_configs=None to TrainGenerator.__init__; will internally call cfg_fn, which validates some configs and tries to fill what’s missing.
  • Above, cfg_fn == misc._make_plot_configs_from_metrics()
deeptrain.visuals._plot_metrics(x_ticks, metrics, plot_kw, mark_best_idx=None, max_is_best=True, axis=None, vhlines=None, ylims=(0, 2), legend_kw=None, key_metric='loss', metric_name_to_alias_fn=None)

Plots metrics according to inputs passed by get_history_fig().

Module contents

deeptrain.scalefig(fig)

Used internally to scale figures according to env var ‘SCALEFIG’.

os.environ[‘SCALEFIG’] can be an int, float, tuple, list, or bracketless tuple, but must be a string: ‘1’, ‘1.1’, ‘(1, 1.1)’, ‘1,1.1’.

deeptrain.set_seeds(seeds=None, reset_graph=False, verbose=1)

Sets random seeds and maybe clears keras session and resets TensorFlow default graph.

NOTE: after calling w/ reset_graph=True, best practice is to re-instantiate model, else some internal operations may fail due to elements from different graphs interacting (of pre-reset model and post-reset tensors).
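
A usage sketch (seed values are arbitrary; the seeds dict format is assumed to match RandomSeedSetter's):

>>> import deeptrain
>>> deeptrain.set_seeds(seeds={'random': 0, 'numpy': 0,
...                            'tf-graph': 1, 'tf-global': 2},
...                     reset_graph=True, verbose=1)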