Data Science project
Develop an Interactive Drawing Recognition App based on CNN — Deploy it with Flask
A quick and easy tutorial about an essential technology for your Data Science projects.
Apr 26
Building Machine Learning models is a common task that you probably feel comfortable with. However, once you are satisfied with a model you have trained and tested offline, what should you do with it? How would you present it to your non-technical boss or client? How would you deploy it online so that other people can use it?
In this article, I will try to tackle those questions. They are rarely covered in school, even though deployment is an important, arguably the most important, part of your Data Science projects.
To do so, I have chosen as an example a drawing application that uses a Convolutional Neural Network (CNN) model to classify drawings made by the user.
Here is the workflow:
I’ll first introduce the model and then describe the app development. I hope this article will be of some help when you deploy your future Data Science projects!
All the following code can be found on my GitHub.
CNN Model
The first part of this project is to prepare the data and build our model!
I have decided to use data from the ‘Quick, Draw!’ game, in which players must draw an arbitrary object as quickly as possible. The dataset is available here.
I restrict my use case to six animals: Cat, Giraffe, Sheep, Bat, Octopus and Camel, which makes the task a multiclass classification problem. Here is a sample of the data:
Preprocessing
Luckily, the images in this dataset are already preprocessed to a uniform 28 × 28 pixel size. Here are the next steps:
- We need to combine our data so we can use it for training and testing. I only use 10,000 samples for this model.
- We then split the features and labels (X and y).
- Finally, we split the data into training and test sets, following the usual 80/20 ratio. We also normalize pixel values to the [0, 1] range (X/255), since the pixels of a grayscale image lie between 0 and 255.
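The three steps above can be sketched with NumPy as follows. This is a minimal sketch, not the author’s code. It assumes each class was downloaded as a Quick, Draw! numpy bitmap file (one row of 784 pixels per drawing). The function name prepare_dataset, the reading of “10,000 samples” as a per-class cap and the shuffling seed are all hypothetical choices.

```python
import numpy as np


def prepare_dataset(class_arrays, n_per_class=10_000, test_ratio=0.2, seed=0):
    """Combine per-class bitmaps, build labels, shuffle, split 80/20, normalize.

    class_arrays: list of uint8 arrays of shape (n_i, 784), one per class
    (assumed layout of the Quick, Draw! numpy bitmaps). The label of each
    drawing is the index of its array in the list.
    """
    rng = np.random.default_rng(seed)
    X_parts, y_parts = [], []
    for label, arr in enumerate(class_arrays):
        subset = arr[:n_per_class]  # hypothetical per-class cap
        X_parts.append(subset)
        y_parts.append(np.full(len(subset), label, dtype=np.int64))

    # Combine all classes, reshape to 28x28 grayscale images, scale to [0, 1].
    X = np.concatenate(X_parts).reshape(-1, 28, 28, 1).astype(np.float32) / 255.0
    y = np.concatenate(y_parts)

    # Shuffle so that the train/test split is not ordered by class.
    order = rng.permutation(len(X))
    X, y = X[order], y[order]

    # 80/20 split by default.
    n_test = int(round(len(X) * test_ratio))
    split = len(X) - n_test
    return X[:split], X[split:], y[:split], y[split:]
```

For instance, prepare_dataset([np.load(path) for path in paths]) would take one .npy path per animal (the paths being placeholders for wherever you saved the files).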
Architecture
Once everything is ready, let’s build our model using Keras! This model will have the following structure:
- Convolutional Layer: 30 filters, (3 × 3) kernel size
- Max Pooling Layer: (2 × 2) pool size
- Convolutional Layer: 15 filters, (3 × 3) kernel size
- Max Pooling Layer: (2 × 2) pool size
- Dropout Layer: drops 20% of neurons
- Flatten Layer
- Dense/Fully Connected Layer: 128 neurons, ReLU activation function
- Dense/Fully Connected Layer: 50 neurons, ReLU activation function
- Output Layer: 6 neurons (one per class), Softmax activation function
Here is the corresponding code:
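As a hedged sketch of the architecture listed above in tf.keras (not necessarily the author’s exact code), the model could look like this. Assumptions not stated in the list: ReLU activations on the convolutional layers, a 28 × 28 × 1 input, a ReLU 50-neuron layer followed by a 6-unit softmax output (one unit per animal), and the Adam optimizer with sparse categorical cross-entropy on integer labels.

```python
from tensorflow import keras

NUM_CLASSES = 6  # Cat, Giraffe, Sheep, Bat, Octopus, Camel


def build_model(num_classes: int = NUM_CLASSES) -> keras.Model:
    """CNN following the layer list above (activations partly assumed)."""
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),                 # 28x28 grayscale drawing
        keras.layers.Conv2D(30, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(15, (3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Dropout(0.2),                      # drop 20% of units while training
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(50, activation="relu"),
        keras.layers.Dense(num_classes, activation="softmax"),  # one probability per class
    ])
    # Integer labels 0..5, hence the "sparse" variant of cross-entropy.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

With valid padding, the spatial size shrinks 28 → 26 → 13 → 11 → 5, so the Flatten layer feeds 5 × 5 × 15 = 375 values into the first dense layer.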
Now that our model is ready, we just need to train it and evaluate its performance.
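Training and evaluation boil down to the fit and evaluate methods of a compiled Keras model. In this minimal sketch, the helper name, the batch size of 200 and the use of the test set as validation data are assumptions; only the 15 epochs come from the article. It also assumes the model was compiled with an accuracy metric, so that evaluate returns [loss, accuracy].

```python
def train_and_evaluate(model, X_train, y_train, X_test, y_test,
                       epochs=15, batch_size=200):
    """Fit a compiled Keras model and return (history, test accuracy).

    Assumes the model was compiled with metrics=["accuracy"], so that
    model.evaluate() returns [loss, accuracy].
    """
    history = model.fit(
        X_train, y_train,
        validation_data=(X_test, y_test),  # track test metrics after each epoch
        epochs=epochs,
        batch_size=batch_size,             # hypothetical value
        verbose=2,
    )
    _, accuracy = model.evaluate(X_test, y_test, verbose=0)
    return history, accuracy
```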
Our classifier reaches 92.7% accuracy after 15 epochs, which is enough for our recognition app! Let’s check the confusion matrix.
As we can see, most of the drawings were correctly classified. However, some classes seem harder to tell apart than others: Cat and Bat, or Camel and Sheep, for example. This can be explained by similarities in their shapes!
Here are some images that were misclassified by our model. Most of them could easily have been mistaken even by a human! Don’t forget that our dataset gathers drawings made by people playing the ‘Quick, Draw!’ game, so many images are poor representatives of their class.
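Both the confusion matrix and the list of misclassified drawings can be computed from predicted labels. Here is a minimal NumPy sketch, assuming integer labels 0 to 5 and predictions obtained as model.predict(X_test).argmax(axis=1). The helper names are hypothetical; scikit-learn’s confusion_matrix would produce the same table.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # Add 1 at (true, predicted) for every sample; np.add.at handles repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def misclassified_indices(y_true, y_pred):
    """Indices of test drawings whose predicted class differs from the truth."""
    return np.flatnonzero(np.asarray(y_true) != np.asarray(y_pred))
```

Off-diagonal cells such as (Cat, Bat) count the confusions discussed above, and the accuracy is the trace of the matrix divided by its sum.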
Saving the model
Now that our model is ready, we would like to embed it into a Flask web app. To do so, it is more convenient to save (serialize) our model using pickle.
Note: you could train your model directly inside the Flask app, but it would be very time-consuming and not user-friendly.
import pickle

# Serialize the trained model so that app.py can load it without retraining
with open('model_cnn.pkl', 'wb') as file:
    pickle.dump(model_cnn, file)
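On the Flask side, the file is read back with pickle.load. The small sketch below pairs both directions in hypothetical helpers. Keep in mind that unpickling can execute arbitrary code, so only load .pkl files you created yourself.

```python
import pickle


def save_model(model, path="model_cnn.pkl"):
    """Serialize any picklable object (here, the trained classifier)."""
    with open(path, "wb") as file:
        pickle.dump(model, file)


def load_model(path="model_cnn.pkl"):
    """Load the serialized classifier, e.g. once when app.py starts."""
    with open(path, "rb") as file:
        return pickle.load(file)
```

Loading the model once at startup, rather than inside each request, avoids repeating the deserialization for every prediction.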
Developing our Drawing Web-App with Flask
Flask
Flask is a web micro-framework written in Python. It allows you to design a solid and professional web application.
How does it work?
Although Flask doesn’t require a specific architecture, there are some good practices to follow:
- app.py: the main file that runs our Flask application. It contains the application’s routes, responds to HTTP requests and decides what to display in the templates. In our case, it also calls our CNN classifier, runs the preprocessing steps on the input data and makes the prediction.
- Templates folder: a template is an HTML file that can receive Python objects and is linked to the Flask application. Our HTML pages are therefore stored in this folder.
- Static folder: style sheets, scripts, images and other elements that are never generated dynamically must be stored in this folder. We will place our JavaScript and CSS files in it.
This project will require:
- Two static files: draw.js and styles_draw.css
- Two template files: draw.html and results.html
- Our main file: app.py
- Our model_cnn.pkl, saved earlier
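To make this layout concrete, here is a minimal skeleton of what app.py could look like. It is a sketch, not the author’s file: classify is a hypothetical placeholder for the decoding, resizing and prediction steps described later. The /predict route explicitly accepts POST because Flask routes only answer GET requests (plus HEAD and OPTIONS) by default.

```python
import flask
from flask import render_template, request

# Templates (draw.html, results.html) live in ./templates, static files in ./static.
app = flask.Flask(__name__, template_folder="templates")


def classify(data_url: str) -> str:
    """Hypothetical placeholder: decode the image and run the CNN here."""
    raise NotImplementedError("plug the trained model in here")


@app.route("/")
def home():
    # Default route: serve the drawing page.
    return render_template("draw.html")


@app.route("/predict", methods=["POST"])
def predict():
    # The hidden form field named "url" carries the base64-encoded drawing.
    final_pred = classify(request.form["url"])
    return render_template("results.html", prediction=final_pred)
```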
Let’s now build our app!
Get the user Input
The second part of this project is to get the user input: a drawing that will be classified by our trained model. To do so, we will first design the drawing area using JavaScript and HTML5. I will not cover styles_draw.css in this article, but you can find it on my GitHub.
draw.html
- We import our CSS and JS files located in the static folder using {{url_for('static',filename='styles_draw.css')}} and {{url_for('static',filename='draw.js')}}. This is the Jinja syntax to import files.
- We set up our drawing area with the <canvas> tag.
- We call the drawCanvas() JavaScript function contained in draw.js.
- We initialize our form so we can use the POST method to send data to our Flask instance, app.py.
- action="{{url_for('predict')}}" is again Jinja syntax. It specifies the path in app.py that will handle the form when it is submitted.
- We add an extra hidden field to our form, which will be used to transfer the image:
<input type="hidden" id="url" name="url" value="">
- That’s all! Easy, right?
We now have to use JavaScript to make it a bit more dynamic! Otherwise, your canvas won’t do anything…
draw.js
This JavaScript code allows us to design and interact with our drawing area.
- drawCanvas() initializes the canvas’ main event handlers (mouseUp, mouseDown, …) that enable interaction with the user’s mouse.
- addClick() saves the cursor’s position when the user clicks on the canvas.
- redraw() clears the canvas and redraws everything each time the function is called.
After my drawing, the canvas looks as follows (this is a giraffe, by the way):
Now that our canvas is ready to capture the user’s drawing, we need to make sure the image can reach our app in app.py.
Usually, we can directly use the POST method and submit data through a form. However, raw image data cannot be submitted this way. Moreover, I didn’t want users to have to save and then upload their drawing, as it would have hurt the fluidity of their experience.
A small trick I used to overcome this issue was to encode the image in base64 before sending it through the form, using the hidden input field set earlier in draw.html. This encoding is then reversed in app.py.
- save() is called when the user clicks the ‘predict’ button. It sends the base64-encoded image through the form.
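In the browser, this encoding is typically produced by canvas.toDataURL('image/png'), which returns a data URL. The Python function below is not part of the app. It only mimics that format, assuming a PNG export, to show what the hidden field ends up carrying.

```python
import base64

PNG_PREFIX = "data:image/png;base64,"


def to_data_url(png_bytes: bytes) -> str:
    """Build the same kind of string canvas.toDataURL('image/png') returns."""
    return PNG_PREFIX + base64.b64encode(png_bytes).decode("ascii")


# Every PNG file starts with the same 8-byte signature, which is why
# encoded drawings always begin with "iVBORw0KGgo".
print(to_data_url(b"\x89PNG\r\n\x1a\n"))  # data:image/png;base64,iVBORw0KGgo=
```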
Make predictions
Now that we can get the user input, it’s time to use it to make a prediction! To do so, we only need one file:
app.py
As stated earlier, app.py is our project’s main file, in which the Flask application is instantiated.
Main points of this code :
1) Initialize the app and specify the template folder. We can do that with this line of code:
app = flask.Flask(__name__, template_folder='templates')
2) Define the routes (only two for our app):
- @app.route('/'): our default path. It returns the default draw.html template.
- @app.route('/predict', methods=['POST']): called when clicking the ‘predict’ button. It returns the results.html template after processing the user input.
3) The predict function is triggered by a POST action from the form (remember when we set this path in draw.html using Jinja syntax!). It then proceeds as follows:
- Access the base64-encoded drawing with request.form['url'], where 'url' is the name of the hidden input field in the form that contains the encoded image.
- Decode the image and convert it into an array.
- Resize and reshape the image to get a 28 × 28 input for our model, taking care to keep its aspect ratio.
- Perform the prediction using our CNN classifier.
- As model.predict() returns a probability for each class in a single array, we find the highest probability and get the corresponding class from a pre-defined dictionary.
- Finally, return the results.html template and pass the prediction as a parameter:
return render_template('results.html', prediction=final_pred)
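The decoding, resizing and class-lookup steps above can be sketched with Pillow and NumPy. This is an assumption-laden sketch rather than the author’s implementation. It shrinks the drawing with Image.thumbnail, which preserves the aspect ratio, and centers it on a black 28 × 28 square. The invert flag is meant for canvases that draw dark strokes on a light background, since the Quick, Draw! bitmaps are light strokes on black; note that convert('L') drops transparency, so a transparent canvas background comes out black. The CLASSES ordering is hypothetical and must match the label order used during training.

```python
import io

import numpy as np
from PIL import Image, ImageOps

# Hypothetical label order: must match the labels used when training.
CLASSES = {0: "Cat", 1: "Giraffe", 2: "Sheep", 3: "Bat", 4: "Octopus", 5: "Camel"}


def preprocess_png(png_bytes: bytes, size: int = 28, invert: bool = False) -> np.ndarray:
    """Turn a PNG drawing into a (1, size, size, 1) float array in [0, 1]."""
    img = Image.open(io.BytesIO(png_bytes)).convert("L")  # grayscale
    if invert:
        img = ImageOps.invert(img)  # dark-on-light -> light-on-dark
    img.thumbnail((size, size))  # shrink in place, keeping the aspect ratio
    canvas = Image.new("L", (size, size), color=0)
    offset = ((size - img.width) // 2, (size - img.height) // 2)
    canvas.paste(img, offset)  # center the drawing on a black square
    arr = np.asarray(canvas, dtype=np.float32) / 255.0
    return arr.reshape(1, size, size, 1)  # batch of one image for model.predict()


def decode_prediction(probabilities: np.ndarray, classes=CLASSES) -> str:
    """Map model.predict() output for one image to its most likely class name."""
    return classes[int(np.argmax(probabilities))]
```

In app.py, the two helpers would be chained as decode_prediction(model.predict(preprocess_png(png_bytes))).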
Note: our base64-encoded image looks something like data:image/png;base64,iVBOR… Thus, we just need to strip the data:image/png;base64, prefix (22 characters, including the comma) to get the clean base64 payload. To do so, we use the draw[init_base64:] line of code, where init_base64 holds the length of that prefix.
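A slightly more robust variant of slicing by a fixed length is to split on the first comma, so the code does not depend on the exact MIME type in the header. Here is a hedged sketch with a hypothetical helper name:

```python
import base64


def decode_data_url(data_url: str) -> bytes:
    """Return the raw bytes behind a 'data:<mime>;base64,<payload>' string."""
    header, sep, payload = data_url.partition(",")
    if not sep or not header.endswith(";base64"):
        raise ValueError("expected a base64-encoded data URL")
    return base64.b64decode(payload)


png_bytes = decode_data_url("data:image/png;base64,iVBORw0KGgo=")
print(png_bytes)  # b'\x89PNG\r\n\x1a\n'
```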
Display the results
That’s all! Almost everything is done. The next step is to display our results.
results.html
Finally, we use results.html to display the prediction computed in app.py. Once again, we use Jinja syntax to display it.
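Flask’s render_template renders templates with Jinja2, so passing prediction=final_pred makes {{ prediction }} available inside results.html. The snippet below reproduces that substitution with Jinja2 directly; the one-line markup is a hypothetical stand-in for the real results.html.

```python
from jinja2 import Template

# Hypothetical stand-in for the relevant line of templates/results.html.
results_line = Template("<p>Prediction: {{ prediction }}</p>")

html = results_line.render(prediction="Giraffe")
print(html)  # <p>Prediction: Giraffe</p>
```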
Here is the rendering when our prediction is “Giraffe”:
Run the app
The last step is to launch our app! Go into your Drawing App folder (with app.py at its root) and run flask run.
Your app will then be running on your local server, by default at 127.0.0.1:5000.
Conclusion
That’s it! In this article, we have seen how to develop a Flask drawing app that uses a previously built CNN model to classify drawings made by the user.
This is one of many possible uses of Flask to deploy machine learning models. In fact, the use cases are countless, and I hope this specific project will help you build other ML web apps that make your code more accessible to others!
All the code is available on my GitHub!
Have a great day!