Converting TensorFlow tutorial to work with my own data

This is a follow-on from my last question, Converting from Pandas dataframe to TensorFlow tensor object.

I'm now on the next step and need some more help. I'm trying

1 Answer

  • 2020-12-03 21:14

    I stripped out the non-relevant parts so as to preserve the formatting and indentation; hopefully it is clear now. The following code reads a CSV file in batches of N lines (N is set in the batch_size constant at the top). Each line contains a date (first cell), then a list of floats (480 cells) and a one-hot vector (3 cells). The code simply prints the batches of dates, floats, and one-hot vectors as it reads them. The place where it prints them is where you would normally run your model and feed these batches in place of its placeholder variables (a rough sketch of that step follows the code below).

    Just keep in mind that the code reads each line as strings and then converts the relevant cells to floats, simply because the first cell (the date) is easiest to read as a string. If all your data is numeric, set the record defaults to a float/int rather than 'a' and drop the string-to-float conversion; it isn't needed in that case.
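
    To illustrate that last point, here is a minimal, untested sketch (not part of the original answer) of what the read function further down could look like with all-numeric defaults; it assumes the same TD/TS/TL column layout:

    import tensorflow as tf

    TD, TS, TL = 1, 480, 3  # same column layout as in the code below

    # every cell is decoded directly as float32, so no string conversion is needed
    rDefaultsNumeric = [[0.0] for _ in range(TD + TS + TL)]

    def read_from_csv_numeric(filename_queue):
        reader = tf.TextLineReader(skip_header_lines=False)
        _, csv_row = reader.read(filename_queue)
        data = tf.decode_csv(csv_row, record_defaults=rDefaultsNumeric)
        data = tf.stack(data)  # list of scalar columns -> 1-D float32 tensor
        dateLbl = tf.slice(data, [0], [TD])
        features = tf.slice(data, [TD], [TS])      # already float32
        label = tf.slice(data, [TD + TS], [TL])    # already float32
        return dateLbl, features, label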

    I put some comments to clarify what it's doing. Let me know if something is unclear.

    import tensorflow as tf
    
    fileName = 'YOUR_FILE.csv'
    
    try_epochs = 1
    batch_size = 3
    
    TD = 1 # this is my date-label for each row, for internal purposes
    TS = 480 # this is the list of features, 480 in this case
    TL = 3 # this is one-hot vector of 3 representing the label
    
    # set defaults to something (TF requires defaults for the number of cells you are going to read)
    rDefaults = [['a'] for _ in range(TD + TS + TL)]
    
    # function that reads the input file, line-by-line
    def read_from_csv(filename_queue):
        reader = tf.TextLineReader(skip_header_lines=False) # my file has no header row
        _, csv_row = reader.read(filename_queue) # read one line
        data = tf.decode_csv(csv_row, record_defaults=rDefaults) # use defaults for this line (in case of missing data)
        dateLbl = tf.slice(data, [0], [TD]) # first cell is my 'date-label', for internal purposes
        features = tf.string_to_number(tf.slice(data, [TD], [TS]), tf.float32) # the next 480 cells are the features
        label = tf.string_to_number(tf.slice(data, [TD+TS], [TL]), tf.float32) # the remaining 3 cells are the one-hot label
        return dateLbl, features, label
    
    # function that packs each read line into batches of specified size
    def input_pipeline(fName, batch_size, num_epochs=None):
        filename_queue = tf.train.string_input_producer(
            [fName],
            num_epochs=num_epochs,
            shuffle=True)  # this refers to multiple files, not line items within files
        dateLbl, features, label = read_from_csv(filename_queue)
        min_after_dequeue = 10000 # minimum number of elements kept in the queue after a dequeue (for good shuffling)
        capacity = min_after_dequeue + 3 * batch_size # maximum number of elements held in the queue
        # this packs the above lines into a batch of size you specify:
        dateLbl_batch, feature_batch, label_batch = tf.train.shuffle_batch(
            [dateLbl, features, label], 
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)
        return dateLbl_batch, feature_batch, label_batch
    
    # these are the date label, features, and label:
    dateLbl, features, labels = input_pipeline(fileName, batch_size, try_epochs)
    
    with tf.Session() as sess:
    
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
    
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
    
        try:
            while not coord.should_stop():
                # load date-label, features, and label:
                dateLbl_batch, feature_batch, label_batch = sess.run([dateLbl, features, labels])      
    
                print(dateLbl_batch)
                print(feature_batch)
                print(label_batch)
                print('----------')
    
        except tf.errors.OutOfRangeError:
            print("Done looping through the file")
    
        finally:
            coord.request_stop()
    
        coord.join(threads)
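
    And for the note above about feeding the batches into a model instead of printing them, here is a purely illustrative sketch (none of this is in the original answer; the linear model is just a stand-in for whatever model you actually have):

    import tensorflow as tf

    TS = 480  # number of features, as above
    TL = 3    # width of the one-hot label, as above

    # hypothetical placeholders the fetched batches would be fed into
    x = tf.placeholder(tf.float32, [None, TS], name='features')
    y_ = tf.placeholder(tf.float32, [None, TL], name='labels')

    # stand-in model: a single linear layer with softmax cross-entropy loss
    W = tf.Variable(tf.zeros([TS, TL]))
    b = tf.Variable(tf.zeros([TL]))
    logits = tf.matmul(x, W) + b
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

    # inside the while-loop above, the print statements would then become something like:
    #   _, batch_loss = sess.run([train_step, loss],
    #                            feed_dict={x: feature_batch, y_: label_batch})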
    