import numpy as np
import tensorflow as tf
from tensorflow import keras
import pandas as pd
 
# DATA INGESTION
# -------------------------------------------------------------------------------------
# Ingest the Training, Validation and Test datasets from image folders.
# The labels (edges vs. noedges) are automatically inferred based on folder names.
 
# Arguments shared by every split: labels come from the sub-folder names and
# are one-hot encoded; images are resized to 343x343 on load.
_common_loader_args = dict(
    labels='inferred',
    label_mode='categorical',
    image_size=(343, 343),
)

# Training and validation are two subsets of the same folder. Both calls use
# the same seed and split fraction so they describe the same 80/20 partition.
_train_val_loader_args = dict(
    _common_loader_args,
    directory='./data/train',
    validation_split=0.2,
    seed=123,
    batch_size=8,
)

train_ds = keras.utils.image_dataset_from_directory(
    subset='training', **_train_val_loader_args)

val_ds = keras.utils.image_dataset_from_directory(
    subset='validation', **_train_val_loader_args)

# The test set is read unshuffled, one image per batch, so that results can
# later be matched against the dataset's file paths.
test_ds = keras.utils.image_dataset_from_directory(
    directory='./data/test',
    shuffle=False,
    batch_size=1,
    **_common_loader_args)

# TRAINING
# -------------------------------------------------------------------------------------
# Build, compile and fit the Deep Learning model

# Xception trained from scratch (weights=None) with a two-class output; the
# input shape matches the 343x343 RGB images produced at ingestion.
# NOTE(review): no pixel rescaling is applied before the model in this script;
# consider keras.applications.xception.preprocess_input -- confirm intent.
model = keras.applications.Xception(
    weights=None, input_shape=(343, 343, 3), classes=2)

# `metrics` is given as a list: Keras 3 raises a ValueError for a bare string
# here, while the list form is accepted by older tf.keras releases as well.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['accuracy'])
model.fit(train_ds, epochs=5, validation_data=val_ds)

# to reuse the model later:
model.save('model_save.h5')
#model = keras.models.load_model('model_save.h5')

# GENERATE PREDICTIONS on previously unseen data
# -------------------------------------------------------------------------------------
# Raw model outputs for every batch of the test dataset.
predictions_raw = model.predict(x=test_ds)

# SUMMARIZE RESULTS (convenience, alternative approaches are available)
# -------------------------------------------------------------------------------------

# Class names as exposed by the test dataset, e.g. ["edges","noedges"]
class_names = test_ds.class_names

# Results table, seeded with one row per test image file path
df = pd.DataFrame({"file_path": test_ds.file_paths})

# add actual labels column
def get_actual_label(label_vector):
    """Translate a one-hot label vector into its class name.

    Looks up the module-level ``class_names`` at the position of the first
    element of ``label_vector`` that equals 1.

    Returns ``None`` when no element equals 1 (including an empty vector).
    """
    hot_positions = (
        position for position, flag in enumerate(label_vector) if flag == 1
    )
    first_hot = next(hot_positions, None)
    if first_hot is None:
        return None
    return class_names[first_hot]
# Gather the label vectors of all test batches into one list of vectors,
# e.g. [[0,1],[1,0],...], each representing a category (edges/noedges).
# test_ds is created with shuffle=False, so iteration order should line up
# with test_ds.file_paths used for the "file_path" column.
actual_label_vectors = np.concatenate(
    [batch_labels for _, batch_labels in test_ds], axis=0).tolist()
actual_labels = [get_actual_label(vector) for vector in actual_label_vectors]
df["actual_label"] = actual_labels


# add predicted labels column
# Index of the highest-scoring class for each prediction row.
predictions_binary = np.argmax(predictions_raw, axis=1)
# Look the name up by index instead of hard-coding a 0/1 choice, so the mapping
# stays correct for any number of classes.
predictions_labeled = [class_names[class_index] for class_index in predictions_binary]

df["predicted_label"] = predictions_labeled


# Write the per-file actual vs. predicted labels, without the index column.
df.to_csv("results.csv", index=False)
