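# 模型过拟合时，可以通过正则化（L2 权重惩罚）限制神经网络权重的大小，同时使用 dropout 在训练时随机断开部分神经元之间的连接，从而提高模型的泛化能力。
# When the model overfits, regularization (an L2 weight penalty) limits the size of the network's weights, and dropout randomly cuts connections between neurons during training; together they improve the model's ability to generalize.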
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import pandas as pd
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint
import os
# Expose only CUDA device "0" to this process.
# NOTE(review): this assignment runs after `import tensorflow` above; it only
# takes effect if TensorFlow has not initialised CUDA yet. Consider moving it
# above the tensorflow import to be safe.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def preprocess(x, y):
    """Scale one (features, label) pair; used with ``tf.data.Dataset.map``.

    Args:
        x: feature tensor (this script builds it as int32).
        y: label tensor.

    Returns:
        ``(x, y)`` with ``x`` cast to float32 and divided by 479, and ``y``
        cast to int32.
    """
    # NOTE(review): 479 is presumably the largest raw value in the feature
    # matrix (so features land in [0, 1]) -- confirm against the data.
    x = tf.cast(x, dtype=tf.float32) / 479
    y = tf.cast(y, dtype=tf.int32)
    return x, y
# Load the feature matrix (the CSV's first column is used as the row index)
# and the labels.
allmatrix = pd.read_csv('re5k6mer_allmatrix_com.csv', header=0, index_col=0,
                        low_memory=True).to_numpy()
target = np.loadtxt('re5k6mer_target_com.txt')
print("mean: ", np.mean(allmatrix), "std: ", np.std(allmatrix))

# Derive the sample count from the data instead of hard-coding it.
n_samples = allmatrix.shape[0]
n_holdout = 50000  # rows reserved for validation + test
n_test = 35000     # of the held-out rows, how many form the test split
if target.shape[0] != n_samples:
    # tf.gather on GPU writes zeros for out-of-range indices instead of
    # raising, so a row-count mismatch has to be caught explicitly.
    raise ValueError("feature matrix has %d rows but target has %d labels"
                     % (n_samples, target.shape[0]))
if n_samples <= n_holdout:
    raise ValueError("need more than %d samples to form a training split, got %d"
                     % (n_holdout, n_samples))

x = tf.convert_to_tensor(allmatrix, dtype=tf.int32)
y = tf.convert_to_tensor(target, dtype=tf.int32)

# Shuffle the row indices once and carve out disjoint splits:
# train = all but the last 50000, val = the next 15000, test = the last 35000.
idx = tf.random.shuffle(tf.range(n_samples))
train_idx = idx[:-n_holdout]
val_idx = idx[-n_holdout:-n_test]
test_idx = idx[-n_test:]
x_train, y_train = tf.gather(x, train_idx), tf.gather(y, train_idx)
x_val, y_val = tf.gather(x, val_idx), tf.gather(y, val_idx)
x_test, y_test = tf.gather(x, test_idx), tf.gather(y, test_idx)
# Input pipelines: every split is passed through `preprocess` and batched;
# only the training split is shuffled (60000-element buffer).
batchsz = 256


def _to_batched_dataset(features, labels, shuffle_buffer=None):
    """Return a batched Dataset of preprocessed (features, labels) pairs.

    When ``shuffle_buffer`` is given, elements are shuffled after
    preprocessing and before batching.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((features, labels))
    pipeline = pipeline.map(preprocess)
    if shuffle_buffer is not None:
        pipeline = pipeline.shuffle(shuffle_buffer)
    return pipeline.batch(batchsz)


db = _to_batched_dataset(x_train, y_train, shuffle_buffer=60000)
ds_val = _to_batched_dataset(x_val, y_val)
db_test = _to_batched_dataset(x_test, y_test)

# Pull one training batch to sanity-check the tensor shapes.
sample = next(iter(db))
features_batch, labels_batch = sample
print(features_batch.shape, labels_batch.shape)
# Fully connected binary classifier. The L2 kernel penalties on the first three
# Dense layers and the Dropout layers between them are the anti-overfitting
# measures; the final single sigmoid unit outputs a value in (0, 1).
l2_factor = 0.0001
network = Sequential([
    layers.Dense(256, activation='relu',
                 kernel_regularizer=tf.keras.regularizers.l2(l2_factor)),
    layers.Dropout(0.3),
    layers.Dense(256, activation='relu',
                 kernel_regularizer=tf.keras.regularizers.l2(l2_factor)),
    layers.Dropout(0.3),
    layers.Dense(128, activation='relu',
                 kernel_regularizer=tf.keras.regularizers.l2(l2_factor)),
    layers.Dropout(0.2),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])
# Take the input width from the loaded training data (previously hard-coded
# as 2080) so the model always matches the feature matrix.
network.build(input_shape=(None, x_train.shape[1]))
network.summary()
# TF2 Keras logs metrics=['accuracy'] as 'accuracy' / 'val_accuracy'. A callback
# monitoring the non-existent 'val_acc' key only logs a warning and never stops
# training or saves a checkpoint.
# Stop after 8 epochs without a >= 0.001 gain in validation accuracy, and restore
# the best epoch's weights when stopping so the test evaluation uses them.
early_stopping = EarlyStopping(monitor='val_accuracy', mode='max',
                               min_delta=0.001, patience=8,
                               restore_best_weights=True)
# Keep only the checkpoint with the highest validation accuracy seen so far.
checkpoint = ModelCheckpoint('re5k6mer_model_3.h5', monitor='val_accuracy',
                             mode='max', verbose=1, save_best_only=True)
network.compile(optimizer=optimizers.Adam(learning_rate=0.001),
                loss='binary_crossentropy',
                metrics=['accuracy'])
# Validate on the whole validation set each epoch. Early stopping and
# checkpointing both rely on this score, so it should not be computed from
# only a couple of batches.
network.fit(db, epochs=100, validation_data=ds_val,
            callbacks=[early_stopping, checkpoint])
test_loss, test_acc = network.evaluate(db_test)
print("test loss: ", test_loss, "test accuracy: ", test_acc)