Source code for npdl.utils.data

# -*- coding: utf-8 -*-


import numpy as np


def one_hot(labels, nb_classes=None):
    """Encode class labels as one-hot vectors.

    One-hot encoding is often used for indicating the state of a state
    machine. When using binary or Gray code, a decoder is needed to determine
    the state. A one-hot state machine, however, does not need a decoder, as
    the state machine is in the nth state if and only if the nth bit is high.

    A ring counter with 15 sequentially ordered states is an example of a
    state machine. A ``one-hot`` implementation would have 15 flip-flops
    chained in series, with the Q output of each flip-flop connected to the D
    input of the next and the D input of the first flip-flop connected to the
    Q output of the 15th flip-flop. The first flip-flop in the chain
    represents the first state, the second represents the second state, and
    so on up to the 15th flip-flop, which represents the last state. Upon
    reset of the state machine, all of the flip-flops are reset to ``0``
    except the first in the chain, which is set to ``1``. The next clock edge
    arriving at the flip-flops advances the one ``hot`` bit to the second
    flip-flop. The ``hot`` bit advances in this way until the 15th state,
    after which the state machine returns to the first state.

    An address decoder converts from binary or Gray code to one-hot
    representation. A priority encoder converts from one-hot representation
    to binary or Gray code.

    In natural language processing, a one-hot vector is a :math:`1 × N`
    matrix (vector) used to distinguish each word in a vocabulary from every
    other word in the vocabulary. The vector consists of 0s in all cells
    except for a single 1 in the cell that uniquely identifies the word.

    Parameters
    ----------
    labels : array_like
        1-D sequence of class labels; converted with ``np.asarray``.
    nb_classes : int, optional
        Number of output columns. Defaults to the number of unique values in
        ``labels`` and must not be smaller than that number.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(labels), nb_classes)``. Column ``i`` is
        ``1`` wherever ``labels`` equals the ``i``-th sorted unique label.
    """
    labels = np.asarray(labels)  # accept any iterable, as documented above
    classes = np.unique(labels)  # sorted unique labels define the column order
    if nb_classes is None:
        nb_classes = classes.size
    one_hot_labels = np.zeros((labels.shape[0], nb_classes))
    for i, c in enumerate(classes):
        # Set column i to 1 in every row whose label equals class c.
        one_hot_labels[labels == c, i] = 1
    return one_hot_labels
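The loop in `one_hot` assigns column `i` to the `i`-th sorted unique label. As an illustration only, the same mapping can likely be written without a Python loop by combining `np.unique(..., return_inverse=True)` with rows of an identity matrix. The function name `one_hot_vectorized` is hypothetical and not part of npdl; this is a sketch assuming 1-D labels.

```python
import numpy as np


def one_hot_vectorized(labels, nb_classes=None):
    """Hypothetical vectorized variant of one_hot (not part of npdl)."""
    labels = np.asarray(labels)
    # classes: sorted unique labels; inverse[k] is the position of labels[k] in classes.
    classes, inverse = np.unique(labels, return_inverse=True)
    if nb_classes is None:
        nb_classes = classes.size
    # Row k of the identity matrix is the one-hot vector for class index k.
    # Like the loop version, this raises IndexError if nb_classes < classes.size.
    return np.eye(nb_classes)[inverse]


labels = np.array([2, 0, 2, 1])
print(one_hot_vectorized(labels))
# [[0. 0. 1.]
#  [1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```

Passing a larger `nb_classes` simply leaves the extra trailing columns at zero, which matches the behavior of the loop over `np.zeros`.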
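

def unhot(one_hot_labels):
    """Convert one-hot encoded rows back to class indices via argmax.

    Parameters
    ----------
    one_hot_labels : numpy.ndarray
        One-hot encoded array whose last axis is the class axis.

    Returns
    -------
    numpy.ndarray
        Integer array of column indices along the last axis. These index the
        sorted unique labels used by ``one_hot``; they equal the original
        label values only when the labels are ``0, 1, ..., n - 1``.
    """
    return np.argmax(one_hot_labels, axis=-1)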
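Because `unhot` returns column positions rather than label values, a round trip on labels that are not `0..n-1` needs one extra lookup into the sorted unique classes. The sketch below is self-contained, so it rebuilds the encoding with the same loop that `one_hot` uses instead of importing npdl; the label values are made up for illustration.

```python
import numpy as np

# Illustrative labels that are not 0..n-1, so indices and label values differ.
labels = np.array([10, 30, 20, 10])
classes = np.unique(labels)  # sorted unique labels: the column order one_hot uses

# Build the one-hot matrix the same way one_hot() does.
encoded = np.zeros((labels.shape[0], classes.size))
for i, c in enumerate(classes):
    encoded[labels == c, i] = 1

indices = np.argmax(encoded, axis=-1)  # what unhot() returns
print(indices)           # [0 2 1 0]
print(classes[indices])  # [10 30 20 10]
```

Note that `np.argmax` returns `0` for an all-zero row, so rows that contain no hot bit are indistinguishable from the first class after decoding.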