import streamlit as st
from sklearn import datasets
import numpy as np
import pandas as pd
# Page heading in the main area, then a heading for the sidebar, which is
# where get_features() places its input sliders.
st.title('Iris Flower Prediction App')
st.sidebar.title('Input Parameters')
# st.cache is deprecated since Streamlit 1.18 (and shows a warning in the
# app); st.cache_data is its replacement for plain data. Fall back to the old
# decorator on Streamlit releases that predate cache_data.
_cache_data = getattr(st, "cache_data", None) or st.cache


@_cache_data
def get_data():
    """Load the Iris dataset, keeping only the first two feature columns.

    Returns:
        A ``(X, y)`` tuple: ``X`` is ``iris.data[:, :2]`` (the two sepal
        measurements the sidebar sliders collect) and ``y`` is
        ``iris.target``, the class labels.
    """
    # Only runs on a cache miss, so this shows in the server console when the
    # data is actually (re)loaded.
    print('Load data ...')
    iris = datasets.load_iris()
    return iris.data[:, :2], iris.target
# st.cache(allow_output_mutation=True) is deprecated since Streamlit 1.18;
# Streamlit's documented replacement for ML models is st.cache_resource.
# Older releases without cache_resource keep using the original decorator.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def train_model(train_X, train_Y):
    """Fit a random-forest classifier on the given training data.

    Args:
        train_X: Training feature matrix.
        train_Y: Class labels aligned with the rows of ``train_X``.

    Returns:
        The fitted ``RandomForestClassifier``. The same cached object is
        handed back on cache hits, so callers should not mutate it.
    """
    # Only runs on a cache miss, i.e. when the model is actually (re)trained.
    print('Train model ...')
    # Imported lazily so scikit-learn's ensemble module is only loaded when a
    # model really has to be trained.
    from sklearn.ensemble import RandomForestClassifier

    # NOTE(review): no random_state is passed, so a retrained forest may
    # differ from a previous one.
    clf = RandomForestClassifier()
    clf.fit(train_X, train_Y)
    return clf
def get_features():
    """Collect the user's sepal measurements from the sidebar.

    Returns:
        A single-row DataFrame (index ``[0]``) with columns
        ``sepal_length`` and ``sepal_width``, in that order.
    """
    # slider(label, min, max, initial value). The bounds appear to match the
    # observed range of each sepal feature in the Iris data.
    length = st.sidebar.slider('Sepal length', 4.3, 7.9, 5.4)
    width = st.sidebar.slider('Sepal width', 2.0, 4.4, 3.4)
    return pd.DataFrame(
        {'sepal_length': length, 'sepal_width': width},
        index=[0],
    )
# Load the (cached) data and model, read the sidebar inputs, and classify them.
X, Y = get_data()
clf = train_model(X, Y)
df = get_features()
# The model was fitted on a bare ndarray (no column names). Predicting on the
# DataFrame directly makes scikit-learn >= 1.0 warn "X has feature names, but
# ... was fitted without feature names" on every rerun. The column order
# (sepal_length, sepal_width) already matches the training columns.
prediction = clf.predict(df.to_numpy())
def display_image(value):
    """Show a picture of the predicted Iris species.

    Args:
        value: Predicted class label. ``0`` shows Iris setosa and ``1`` Iris
            versicolor; any other value falls back to Iris virginica.
    """
    from PIL import Image

    # Class label -> (image file, caption). Paths are resolved against the
    # current working directory.
    images = {
        0: ('Iris-setosa.jpg', 'Iris Setosa'),
        1: ('IRIS_VERSICOLOR.jpeg', 'Iris Versicolor'),
    }
    path, caption = images.get(value, ('Iris_virginica.jpg', 'Iris Virginica'))
    # Context manager closes the file handle deterministically once the image
    # has been handed to Streamlit.
    with Image.open(path) as image:
        st.image(image, caption=caption, width=200)
# df holds a single row, so the first prediction is the one to render.
predicted_label = prediction[0]
display_image(predicted_label)