Day31: Saving neural networks

Posted by csiu on March 27, 2017 | with: 100daysofcode, Machine Learning

To fit the CNN into my architecture, I need to change a few things: (1) fix my input format, and (2) define the CNN and make it “play nice” with the pre-existing MLP. I also want to (3) be able to save the trained network and reload it in a later, separate run.

Reformat input

As mentioned yesterday, the input of a Lasagne CNN needs to be a nested data structure. Since I want to use a 1D-CNN, the format needs to be: dataset[sample][channel][features]. I define the following function to reformat my matrix into this format.

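import numpy as np
import theano

def reformat_input_matrix2convnet1d(X):
    """Reformat matrix shape (samples,features)
    to nested input for 1d-cnn (samples,channel,features)"""
    # Iterating an np.matrix yields (1, features) rows: one channel per sample
    X = np.matrix(X, dtype=theano.config.floatX)
    return np.array([r for r in X])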
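As a quick check of the target shape, here is a minimal NumPy-only sketch. It is not the function above: it assumes float32 in place of theano.config.floatX and adds the channel axis with np.newaxis, which should give the same (samples,channel,features) layout on a toy matrix.

```python
import numpy as np

# Toy matrix: 4 samples, 3 features (values are arbitrary)
X = np.arange(12, dtype="float32").reshape(4, 3)

# Insert a channel axis of length 1 between samples and features
X_cnn = X[:, np.newaxis, :]

print(X.shape)      # (4, 3)
print(X_cnn.shape)  # (4, 1, 3)
```

Indexing then follows dataset[sample][channel][features], so X_cnn[2][0] is the feature vector of the third sample.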

ASIDE: I want to use a 1D-CNN since I will be repurposing this code for predicting classes using epigenetic 1D sequence data.

Save model

Training the model takes time, especially with more epochs and with wider and deeper models. I want to train it once and then use the trained model at a later time or date. According to mahdieh@stackoverflow (2016), you can save the model and its parameters with pickle.

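import cPickle as pickle  # Python 2; on Python 3 use `import pickle`
import lasagne

# Save model
net_info = {'network': network,
            'params': lasagne.layers.get_all_param_values(network)}
with open(model_file, 'wb') as f:
    pickle.dump(net_info, f, protocol=pickle.HIGHEST_PROTOCOL)

# Load model
with open(model_file, 'rb') as f:
    net = pickle.load(f)
# In a fresh session, take the network from the pickle (or rebuild it first)
network = net['network']
all_params = net['params']
lasagne.layers.set_all_param_values(network, all_params)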

NOTE: I implemented these features as a class.
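One way such a class could look is sketched below. This is not my actual class: ParamCheckpoint and its method names are hypothetical, and a list of NumPy arrays stands in for the output of lasagne.layers.get_all_param_values(network) so that the example runs without Lasagne.

```python
import os
import pickle
import tempfile

import numpy as np


class ParamCheckpoint(object):
    """Hypothetical helper: save and restore a list of parameter arrays."""

    def __init__(self, path):
        self.path = path

    def save(self, param_values):
        # param_values: list of NumPy arrays, e.g. what
        # lasagne.layers.get_all_param_values(network) returns
        with open(self.path, "wb") as f:
            pickle.dump(param_values, f, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self):
        # Returns the list of arrays exactly as it was saved
        with open(self.path, "rb") as f:
            return pickle.load(f)


# Stand-in parameters (a weight matrix and a bias vector)
params = [np.ones((3, 2), dtype="float32"), np.zeros(2, dtype="float32")]

path = os.path.join(tempfile.mkdtemp(), "model.pkl")
ckpt = ParamCheckpoint(path)
ckpt.save(params)
restored = ckpt.load()
```

The restored list could then be passed to lasagne.layers.set_all_param_values on a network with the same structure. Storing only the parameter values keeps the file independent of how the network object itself pickles, at the cost of having to rebuild the network before loading.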

More refactoring

With the new features and the addition of a CNN, I need to refactor my code. The main change is turning the MLP from a class into a function that only defines the network structure. This way I can specify another network structure (e.g. a CNN) and reuse by default the same objective, update, and training code originally available to the BasicMLP class.

NOTE: I hard-coded network_type = "cnn" for now.
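The idea of swapping network structures behind a single network_type switch can be sketched roughly as follows. Everything here is illustrative: build_mlp, build_cnn, BUILDERS, and build_network are hypothetical names, and the builders return plain lists as placeholders where the real code would return Lasagne layers.

```python
def build_mlp(input_dim, num_classes):
    # Placeholder: a real version would return a Lasagne output layer
    return [("input", input_dim), ("dense", 100), ("softmax", num_classes)]


def build_cnn(input_dim, num_classes):
    # Placeholder for a 1D-CNN; the input is (channel, features)
    return [("input", (1, input_dim)), ("conv1d", 16), ("softmax", num_classes)]


# Hypothetical registry mapping network_type to a builder function
BUILDERS = {"mlp": build_mlp, "cnn": build_cnn}


def build_network(network_type, input_dim, num_classes):
    # Fail early on an unknown structure name
    if network_type not in BUILDERS:
        raise ValueError("unknown network_type: %r" % network_type)
    return BUILDERS[network_type](input_dim, num_classes)


network_type = "cnn"  # hard-coded, as in the note above
network = build_network(network_type, input_dim=20, num_classes=2)
```

Because every builder takes the same arguments and returns the same kind of object, the objective, update, and training code only ever needs the returned network and does not care which structure produced it.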