### Exercise: Plant Disease App

We want to build a plant disease app that can detect whether a plant is healthy or diseased, given a photo of one of its leaves 🍃. Here are a few examples of each condition:

| Healthy  | Angular Leaf Spot  | Bean Rust |
|---|---|---|
| ![](exercise_3/healthy/healthy_train.0.jpg)  | ![](exercise_3/angular_leaf_spot/angular_leaf_spot_train.215.jpg)  | ![](exercise_3/bean_rust/bean_rust_train.214.jpg)  |

All the images are located in the `exercise_3` folder, already split into the three conditions.

**a)** Discuss, in your own words, how you would approach this problem.

That is, state what type of learning problem it is and which kinds of models you could use to solve it.

**b)** Implement a solution based on what you previously described. 
**Important: you must evaluate your model on a held-out test set and report the appropriate metrics.**

As a starting point, you can use the following code to get the file path and the label of each image:

In [1]:
import pandas as pd
import glob

In [2]:
image_paths = glob.glob('exercise_3/*/*')

labels = []

for image_path in image_paths:
    # The label is the name of the folder that contains the image
    labels.append(image_path.split('/')[1])
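Before training, it can help to check that the labels were parsed as expected and how many images each class has. The sketch below is only illustrative: it uses the three example paths from the table above instead of the real `glob` result, and the names `example_paths`, `example_labels`, and `label_counts` are introduced here, not part of the exercise.

```python
from collections import Counter
from pathlib import PurePosixPath

# Illustrative paths taken from the example table; in the notebook,
# `image_paths` from glob.glob would be used instead.
example_paths = [
    'exercise_3/healthy/healthy_train.0.jpg',
    'exercise_3/angular_leaf_spot/angular_leaf_spot_train.215.jpg',
    'exercise_3/bean_rust/bean_rust_train.214.jpg',
]

# The parent folder name is the class label.
example_labels = [PurePosixPath(p).parent.name for p in example_paths]

# Count the images per class to spot imbalance or parsing mistakes.
label_counts = Counter(example_labels)
print(label_counts)
```

Running the same `Counter` over the real `labels` list would show whether the three classes are roughly balanced, which matters when choosing metrics.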

In [3]:
import torch
from torchvision.io import read_image, ImageReadMode
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchvision.models.feature_extraction import create_feature_extractor

# Step 1: Initialize model with the best available weights
weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=weights)
model.eval()  # inference mode for dropout/batch norm; gradients are disabled separately with torch.no_grad()

# Step 2: Initialize the preprocess function
preprocess = weights.transforms()


layer_before_final_classifiers = 'flatten'

return_nodes = {
    layer_before_final_classifiers: layer_before_final_classifiers
}

feature_extractor = create_feature_extractor(model, return_nodes=return_nodes)

In [4]:
cnn_codes = []
for i, image_name in enumerate(image_paths):
    print(f'Processing image {i+1}/{len(image_paths)}')
    with torch.no_grad():
        img = read_image(image_name, mode=ImageReadMode.RGB)
        img_processed = preprocess(img).unsqueeze(0)
        cnn_code = feature_extractor(img_processed)[layer_before_final_classifiers]
    cnn_codes.append(cnn_code)


cnn_codes = torch.cat(cnn_codes)

Processing image 1/1034
Processing image 2/1034
Processing image 3/1034
...

In [5]:
cnn_codes = cnn_codes.numpy()

In [6]:
assert len(cnn_codes) == len(labels)

In [7]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(cnn_codes, labels, test_size=0.2, random_state=0)
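The split above is purely random, so the class proportions in the test set may differ slightly from those in the full dataset. One optional variation, sketched here on synthetic data rather than the real CNN codes, is to pass `stratify` so each class keeps its share in both splits. The arrays `X_demo` and `y_demo` and the class counts are made up for the illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the feature matrix and labels (not the real data).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(100, 8))
y_demo = np.array(
    ['healthy'] * 20 + ['angular_leaf_spot'] * 40 + ['bean_rust'] * 40
)

# stratify=y_demo keeps the 20/40/40 class proportions in both splits.
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=0, stratify=y_demo
)

classes, counts = np.unique(y_te, return_counts=True)
test_counts = {str(c): int(n) for c, n in zip(classes, counts)}
print(test_counts)
```

Note that switching the notebook to a stratified split would change which images land in the test set, so the reported metrics would change as well.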

In [8]:
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(random_state=0, max_iter=1000).fit(X_train, y_train)

y_pred = clf.predict(X_test)

from sklearn.metrics import classification_report

print(classification_report(y_test, y_pred))

                   precision    recall  f1-score   support

angular_leaf_spot       0.86      0.82      0.84        73
        bean_rust       0.86      0.84      0.85        77
          healthy       0.90      0.96      0.93        57

         accuracy                           0.87       207
        macro avg       0.87      0.88      0.87       207
     weighted avg       0.87      0.87      0.87       207
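To see how the summary rows relate to the per-class rows, the F1 scores and averages can be recomputed by hand. The sketch below copies the rounded precision, recall, and support values from the report, so the results only agree with the report to about two decimals; the helper `f1_from` and the other names are introduced here for illustration.

```python
# Per-class (precision, recall, support), copied from the report above.
report = {
    'angular_leaf_spot': (0.86, 0.82, 73),
    'bean_rust': (0.86, 0.84, 77),
    'healthy': (0.90, 0.96, 57),
}


def f1_from(precision, recall):
    """F1 is the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


f1_per_class = {name: f1_from(p, r) for name, (p, r, _) in report.items()}
total_support = sum(s for _, _, s in report.values())

# Macro average: every class counts equally.
macro_f1 = sum(f1_per_class.values()) / len(f1_per_class)

# Weighted average: every class counts in proportion to its support.
weighted_f1 = sum(
    f1_per_class[name] * s for name, (_, _, s) in report.items()
) / total_support

for name, value in f1_per_class.items():
    print(f'{name}: {value:.2f}')
print(f'macro avg: {macro_f1:.2f}, weighted avg: {weighted_f1:.2f}')
```

Because the three classes have similar support here, the macro and weighted averages are close; with a strongly imbalanced test set they could diverge, and the macro average would be the more conservative figure to report.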

