Reading a CSV File with the TensorFlow Data API and Splitting the Tensor into Training and Test Sets for an LSTM
There may be times when all of your data lives in one huge CSV file and you need to feed it into TensorFlow while also splitting it into two sets: training and testing. Scikit-Learn's train_test_split function is not suitable here, because the file is read with a text-line reader from the TensorFlow Data API, so the data is already a tensor; Scikit-Learn works on NumPy arrays, not TensorFlow tensors.
Here, we will show how to do this efficiently with the TensorFlow Data API. First, let's create a Pandas DataFrame consisting of 204 rows, 33 features and one binary label column, and save it as a CSV file in our working directory.
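A minimal sketch of this step is shown below; the file name data.csv and the column names are assumptions made for illustration, and the values are random.

```python
import numpy as np
import pandas as pd

np.random.seed(0)

n_rows, n_features = 204, 33  # 204 rows of 33 features each

features = np.random.randn(n_rows, n_features)
labels = np.random.randint(0, 2, size=(n_rows, 1))  # one binary label per row

columns = ["feature_%d" % i for i in range(n_features)] + ["label"]
df = pd.DataFrame(np.hstack([features, labels]), columns=columns)
df["label"] = df["label"].astype(int)

# keep the header row, drop the index column (this layout is assumed by the reader below)
df.to_csv("data.csv", index=False)
```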
We will also explain how to transform this dataset into a structure that an LSTM network can consume. An LSTM expects input of shape [batch_size, time_steps, number_of_features]. Here, batch_size will be kept dynamic. Let's choose time_steps = 6; we already know that the number of features is 33. Each row in the CSV file therefore represents one time step, and every 6 consecutive rows denote one person (one customer, one observation, et cetera). The total number of observations is then 204 / 6 = 34.
Let’s split the dataset into two and use $80\%$ of it for training and the rest for testing. Therefore, we will have 27 observations (162 rows) for training and 7 observations (42 rows) for testing.
The input_fn() function decodes the CSV. First it combines every 6 lines into a single observation, then parses the 33 float feature columns and the one integer label column. It also one-hot-encodes the label variable.
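A possible implementation is sketched below using the TF 1.x tf.data API. It assumes the CSV layout produced above (33 float columns followed by one integer label, with a header row and no index column); the names TIME_STEPS, NUM_FEATURES and parse_block are ours, not part of any TensorFlow API.

```python
import tensorflow as tf

TIME_STEPS = 6
NUM_FEATURES = 33

def parse_block(lines):
    """Decode a block of TIME_STEPS CSV lines into one observation."""
    record_defaults = [[0.0]] * NUM_FEATURES + [[0]]
    columns = tf.decode_csv(lines, record_defaults=record_defaults)
    # each column has shape [TIME_STEPS]; stack the features into [TIME_STEPS, NUM_FEATURES]
    features = tf.stack(columns[:NUM_FEATURES], axis=1)
    # one label per observation: take the label of the first row and one-hot-encode it
    label = tf.one_hot(columns[-1][0], depth=2)
    return features, label

def input_fn(file_path):
    return (tf.data.TextLineDataset(file_path)
            .skip(1)            # skip the CSV header row
            .batch(TIME_STEPS)  # 204 / 6 = 34 complete observations, so no remainder
            .map(parse_block))
```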
Let's choose batch_size = 3, which means that every batch consists of 3 observations (18 rows). The training set therefore yields 9 batches with no remainder, and the test set yields 3 batches, with the last batch containing only one observation (6 rows).
As can easily be seen, we use the .take() and .skip() functions of the TensorFlow Data API here.
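A sketch of the split, under the assumptions above (the file name data.csv and the constant names are ours):

```python
BATCH_SIZE = 3
TRAIN_OBS = 27  # roughly 80% of the 34 observations

dataset = input_fn("data.csv")

# first 27 observations -> 9 full batches of 3 observations (18 rows) each
train_dataset = dataset.take(TRAIN_OBS).batch(BATCH_SIZE)
# remaining 7 observations -> 3 batches, the last one holding a single observation
test_dataset = dataset.skip(TRAIN_OBS).batch(BATCH_SIZE)
```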
Additionally, we use a reinitializable iterator so that we can switch dynamically between the two input streams: a single iterator is created and then initialized against either dataset. Note that both datasets must have the same data types and shapes. Also, do not forget that the iterator has to be initialized before you start pulling data from it.
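Below is one way to wire this up with the TF 1.x session API; it is only a sketch, and the training and evaluation steps are left as placeholders.

```python
# one iterator built from the common structure of both datasets
iterator = tf.data.Iterator.from_structure(train_dataset.output_types,
                                            train_dataset.output_shapes)
next_features, next_labels = iterator.get_next()  # shapes [batch, 6, 33] and [batch, 2]

train_init_op = iterator.make_initializer(train_dataset)
test_init_op = iterator.make_initializer(test_dataset)

with tf.Session() as sess:
    # training phase: initialize the iterator with the training set
    sess.run(train_init_op)
    # ... run training steps that consume next_features / next_labels ...

    # evaluation phase: re-initialize the same iterator with the test set
    sess.run(test_init_op)
    # ... run evaluation steps ...
```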