Source code for deepmatcher.runner

import logging
import os
import pdb
import time
import sys
import math

import torch

from . import optim as optim
from .data import MatchingIterator
from .loss import SoftNLLLoss
from tqdm import tqdm
from collections import OrderedDict

try:
    get_ipython
    from tqdm import tqdm_notebook as tqdm
except NameError:
    from tqdm import tqdm

logger = logging.getLogger(__name__)


[docs]class Statistics(object): """ Accumulator for loss statistics, inspired by ONMT. Currently calculates: * F1 * Precision * Recall * Accuracy """ def __init__(self): self.loss_sum = 0 self.examples = 0 self.tps = 0 self.tns = 0 self.fps = 0 self.fns = 0 self.start_time = time.time()
[docs] def update(self, loss=0, tps=0, tns=0, fps=0, fns=0): examples = tps + tns + fps + fns self.loss_sum += loss * examples self.tps += tps self.tns += tns self.fps += fps self.fns += fns self.examples += examples
[docs] def loss(self): return self.loss_sum / self.examples
[docs] def f1(self): prec = self.precision() recall = self.recall() return 2 * prec * recall / max(prec + recall, 1)
[docs] def precision(self): return 100 * self.tps / max(self.tps + self.fps, 1)
[docs] def recall(self): return 100 * self.tps / max(self.tps + self.fns, 1)
[docs] def accuracy(self): return 100 * (self.tps + self.tns) / self.examples
[docs] def examples_per_sec(self): return self.examples / (time.time() - self.start_time)
[docs]class Runner(object):
[docs] @staticmethod def print_stats(name, epoch, batch, n_batches, stats, cum_stats): """Write out statistics to stdout. """ print((' | {name} | [{epoch}][{batch:4d}/{n_batches}] || Loss: {loss:7.4f} |' ' F1: {f1:7.2f} | Prec: {prec:7.2f} | Rec: {rec:7.2f} ||' ' Cum. F1: {cf1:7.2f} | Cum. Prec: {cprec:7.2f} | Cum. Rec: {crec:7.2f} ||' ' Ex/s: {eps:6.1f}').format( name=name, epoch=epoch, batch=batch, n_batches=n_batches, loss=stats.loss(), f1=stats.f1(), prec=stats.precision(), rec=stats.recall(), cf1=cum_stats.f1(), cprec=cum_stats.precision(), crec=cum_stats.recall(), eps=cum_stats.examples_per_sec()))
[docs] @staticmethod def print_final_stats(epoch, runtime, datatime, stats): """Write out statistics to stdout. """ print(('Finished Epoch {epoch} || Run Time: {runtime:7f} | ' 'Load Time: {datatime:7f} | F1: {f1:7.2f} | Prec: {prec:7.2f} | ' 'Rec: {rec:7.2f} || Ex/s: {eps:6.1f}\n').format( epoch=epoch, runtime=runtime, datatime=datatime, f1=stats.f1(), prec=stats.precision(), rec=stats.recall(), eps=stats.examples_per_sec()))
[docs] @staticmethod def set_pbar_status(pbar, stats, cum_stats): postfix_dict = OrderedDict([ ('Loss', '{0:7.4f}'.format(stats.loss())), ('F1', '{0:7.2f}'.format(stats.f1())), ('Cum. F1', '{0:7.2f}'.format(cum_stats.f1())), ('Ex/s', '{0:6.1f}'.format(cum_stats.examples_per_sec())), ]) pbar.set_postfix(ordered_dict=postfix_dict)
[docs] @staticmethod def compute_scores(output, target): predictions = output.max(1)[1].data correct = (predictions == target.data).float() incorrect = (1 - correct).float() positives = (target.data == 1).float() negatives = (target.data == 0).float() tp = torch.dot(correct, positives) tn = torch.dot(correct, negatives) fp = torch.dot(incorrect, negatives) fn = torch.dot(incorrect, positives) return tp, tn, fp, fn
[docs] @staticmethod def tally_parameters(model): n_params = sum([p.nelement() for p in model.parameters() if p.requires_grad]) print('* Number of trainable parameters:', n_params)
@staticmethod def _run(run_type, model, dataset, criterion=None, optimizer=None, train=False, device=None, save_path=None, batch_size=32, num_data_workers=2, batch_callback=None, epoch_callback=None, progress_style='bar', log_freq=5, sort_in_buckets=None, return_predictions=False, **kwargs): sort_in_buckets = train run_iter = MatchingIterator( dataset, model.train_dataset, batch_size=batch_size, device=device, sort_in_buckets=sort_in_buckets) if device == 'cpu': model = model.cpu() if criterion: criterion = criterion.cpu() elif torch.cuda.is_available(): model = model.cuda() if criterion: criterion = criterion.cuda() elif device == 'gpu': raise ValueError('No GPU available.') if train: model.train() else: model.eval() # Init model init_batch = next(run_iter.__iter__()) model(init_batch) epoch = model.epoch datatime = 0 runtime = 0 cum_stats = Statistics() stats = Statistics() predictions = {} id_attr = model.train_dataset.id_field label_attr = model.train_dataset.label_field if train and epoch == 0: Runner.tally_parameters(model) epoch_str = 'epoch ' + str(epoch + 1) + ' :' print('=> ', run_type, epoch_str) batch_end = time.time() if progress_style == 'bar': pbar = tqdm(total=len(run_iter) // log_freq, bar_format='{l_bar}{bar}{postfix}', file=sys.stdout) for batch_idx, batch in enumerate(run_iter): batch_start = time.time() datatime += batch_start - batch_end output = model(batch) # from torchviz import make_dot, make_dot_from_trace # dot = make_dot(output.mean(), params=dict(model.named_parameters())) # pdb.set_trace() loss = float('NaN') if criterion: loss = criterion(output, getattr(batch, label_attr)) if label_attr: scores = Runner.compute_scores(output, getattr(batch, label_attr)) else: scores = output.data.new([0] * output.shape[0]) cum_stats.update(float(loss), *scores) stats.update(float(loss), *scores) if return_predictions: predicted = output.max(1)[1].data for idx, id in enumerate(getattr(batch, id_attr)): predictions[id] = float(output[idx, 1].exp()) if (batch_idx + 1) % log_freq == 0: if progress_style == 'log': Runner.print_stats(run_type, epoch + 1, batch_idx + 1, len(run_iter), stats, cum_stats) elif progress_style == 'bar': pbar.update() Runner.set_pbar_status(pbar, stats, cum_stats) stats = Statistics() if train: model.zero_grad() loss.backward() if not optimizer.params: optimizer.set_parameters(model.named_parameters()) optimizer.step() batch_end = time.time() runtime += batch_end - batch_start pbar.close() Runner.print_final_stats(epoch + 1, runtime, datatime, cum_stats) if return_predictions: return predictions else: return cum_stats.f1()
[docs] @staticmethod def train(model, train_dataset, validation_dataset, epochs=50, criterion=None, optimizer=None, pos_weight=1, label_smoothing=False, best_save_path=None, save_every_prefix=None, save_every_freq=None, **kwargs): model.initialize(train_dataset) model.register_train_buffer('optimizer_state') model.register_train_buffer('best_score') model.register_train_buffer('epoch') if criterion is None: assert pos_weight < 2 neg_weight = 2 - pos_weight criterion = SoftNLLLoss(label_smoothing, torch.Tensor([neg_weight, pos_weight])) optimizer = optimizer or optim.Optimizer() if model.optimizer_state is not None: model.optimizer.base_optimizer.load_state_dict(model.optimizer_state) if model.epoch is None: epochs_range = range(epochs) else: epochs_range = range(model.epoch + 1, epochs) if model.best_score is None: model.best_score = -1 optimizer.last_acc = model.best_score for epoch in epochs_range: model.epoch = epoch Runner._run( 'TRAIN', model, train_dataset, criterion, optimizer, train=True, **kwargs) score = Runner._run('EVAL', model, validation_dataset, train=False, **kwargs) optimizer.update_learning_rate(score, epoch) model.optimizer_state = optimizer.base_optimizer.state_dict() new_best_found = False if score > model.best_score: print('* Best F1:', score) model.best_score = score new_best_found = True if best_save_path and new_best_found: print('Saving best model...') model.save_state(best_save_path) if save_every_prefix is not None and (epoch + 1) % save_every_freq == 0: save_path = '{prefix}_ep{epoch}.pth'.format( prefix=save_every_prefix, epoch=epoch) model.save_state(save_path)
[docs] def eval(model, dataset, **kwargs): return Runner._run('EVAL', model, dataset, train=False, **kwargs)
[docs] def predict(model, dataset, **kwargs): return Runner._run('PREDICT', model, dataset, return_predictions=True, **kwargs)