2022-11-01
DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks
DeepGBM
This is the implementation of the paper "DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks", which was accepted at KDD'2019 as an oral paper in the Research Track. More information is available in the video and the paper.
If you find this code useful in your research, please cite the paper:
Guolin Ke, Zhenhui Xu, Jia Zhang, Jiang Bian, and Tie-Yan Liu. "DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks." In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2019: 384-394.
Brief Introduction
This repo contains the experimental code for our paper, including all data preprocessing, the baseline model implementations, and the proposed model implementation (full code here). For a quick start, only the code related to our model is shown here. The GBDT-based models are implemented with LightGBM, and the NN-based models are implemented with PyTorch.
There are three main folders in the project: data stores the datasets, preprocess contains feature selection and encoding, and models contains all the implementation code of the proposed model. For more detailed experiment code, refer to the experiments folder.
main.py is the entry point for our model. In addition, data_helpers.py contains the data loader, helper.py contains the general NN training and testing logic, and train_models.py contains the model-specific training procedures. The models folder contains the implementations of the main models, and tree_model_interpreter.py is used to interpret the structure of the trained GBDT.
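tree_model_interpreter.py itself is not reproduced here. As a rough illustration of what "interpreting a trained GBDT's structure" can involve, the sketch below walks the nested dictionary returned by LightGBM's `Booster.dump_model()` and reports, for each tree, which feature indices it splits on and how many leaves it has. The helper names `summarize_trees` and `_walk` are hypothetical, and this is an assumption-based sketch, not the repo's implementation.

```python
from typing import Any, Dict, List, Set


def _walk(node: Dict[str, Any], features: Set[int]) -> int:
    """Recursively visit one tree node; return the number of leaves below it."""
    if "split_feature" not in node:
        # Leaf node: LightGBM stores leaf_value (and usually leaf_index) here.
        return 1
    features.add(node["split_feature"])
    return _walk(node["left_child"], features) + _walk(node["right_child"], features)


def summarize_trees(model_dump: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Summarize each tree in a dict shaped like LightGBM's Booster.dump_model() output.

    Returns, per tree, the sorted feature indices used in splits and the leaf count.
    """
    summaries = []
    for tree in model_dump["tree_info"]:
        used: Set[int] = set()
        n_leaves = _walk(tree["tree_structure"], used)
        summaries.append({"tree_index": tree["tree_index"],
                          "used_features": sorted(used),
                          "num_leaves": n_leaves})
    return summaries


if __name__ == "__main__":
    # Hand-written dump with the same nesting as LightGBM's output (values are made up).
    toy_dump = {
        "tree_info": [
            {"tree_index": 0,
             "tree_structure": {
                 "split_feature": 2, "threshold": 0.5,
                 "left_child": {"leaf_index": 0, "leaf_value": -0.1},
                 "right_child": {
                     "split_feature": 0, "threshold": 1.5,
                     "left_child": {"leaf_index": 1, "leaf_value": 0.2},
                     "right_child": {"leaf_index": 2, "leaf_value": 0.4}}}},
        ]
    }
    print(summarize_trees(toy_dump))
```

With a real trained model, the same function could be called as `summarize_trees(booster.dump_model())`. Per-tree feature sets like these are the kind of structural information a GBDT-to-NN distillation step may need.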
Environment Setting
The main dependencies are as follows:
Python==3.6.6
LightGBM==2.2.1
Pytorch==0.4.1
Sklearn==0.19.2
Quick Start
All datasets should first be converted into .csv files and then processed by the encoders in preprocess. The features used for each dataset can be found in preprocess/encoding_*.py, specifically in the main function.
To run DeepGBM, after preprocessing, prepare your data in .npy format. Then use the functions in data_helpers.py to load its numerical and categorical parts:
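The actual encoders in preprocess/encoding_*.py are not shown here. The sketch below only illustrates the general idea under stated assumptions: each CSV row is split into a numerical part and a categorical part, every categorical column gets its own vocabulary mapping values to integer ids in order of first appearance, and empty numerical cells are filled with 0. The function name `encode_rows`, the toy column names, and the missing-value rule are all hypothetical.

```python
import csv
import io
from typing import Dict, List, Tuple

import numpy as np


def encode_rows(rows: List[Dict[str, str]], num_cols: List[str], cate_cols: List[str]
                ) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, int]]]:
    """Split CSV rows into a float numerical matrix and an int categorical matrix."""
    # One vocabulary per categorical column; ids follow order of first appearance.
    vocabs: List[Dict[str, int]] = [{} for _ in cate_cols]
    num = np.zeros((len(rows), len(num_cols)), dtype=np.float32)
    cate = np.zeros((len(rows), len(cate_cols)), dtype=np.int64)
    for i, row in enumerate(rows):
        for j, col in enumerate(num_cols):
            value = row[col]
            # Assumption: an empty numerical cell is treated as 0.
            num[i, j] = float(value) if value != "" else 0.0
        for j, col in enumerate(cate_cols):
            # setdefault assigns the next free id to a value seen for the first time.
            cate[i, j] = vocabs[j].setdefault(row[col], len(vocabs[j]))
    return num, cate, vocabs


if __name__ == "__main__":
    text = "age,income,city,device\n25,3000,NY,ios\n31,,SF,android\n40,5200,NY,android\n"
    rows = list(csv.DictReader(io.StringIO(text)))
    num, cate, vocabs = encode_rows(rows, ["age", "income"], ["city", "device"])
    print(num)
    print(cate)
    print(vocabs)
```

The two arrays could then be written with `np.save` to obtain .npy files. The exact file names and layout that `dh.load_data` expects are not documented in this README, so check data_helpers.py before relying on any particular naming.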
import data_helpers as dh  # data_helpers.py from this repo

num_data = dh.load_data(args.data + '_num')
cate_data = dh.load_data(args.data + '_cate')
# The following is designed for faster CatNN inputs
cate_data = dh.trans_cate_data(cate_data)
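The README does not say what `trans_cate_data` does internally. One common trick for faster categorical-embedding input, shown here as an assumption rather than a description of that function, is to shift each field's ids by the total vocabulary size of the preceding fields. Every field can then be looked up in one shared embedding table with a single indexing operation instead of one lookup per field. `offset_fields` is a hypothetical name.

```python
import numpy as np


def offset_fields(cate: np.ndarray, field_sizes=None):
    """Shift per-field category ids into one global id space.

    cate: int array of shape (n_samples, n_fields); column k holds ids in [0, size_k).
    Returns the shifted array and the total number of ids (the shared table size).
    """
    if field_sizes is None:
        # Assumption: each field's vocabulary size is its max id + 1 in this data.
        field_sizes = cate.max(axis=0) + 1
    field_sizes = np.asarray(field_sizes, dtype=np.int64)
    # Offset of field k is the sum of the sizes of fields 0..k-1.
    offsets = np.concatenate(([0], np.cumsum(field_sizes)[:-1]))
    return cate + offsets, int(field_sizes.sum())


if __name__ == "__main__":
    cate = np.array([[0, 1, 2],
                     [1, 0, 0]])
    shifted, table_size = offset_fields(cate)
    # A single (table_size, dim) matrix can now serve all fields in one lookup.
    emb = np.random.default_rng(0).normal(size=(table_size, 4))
    vectors = emb[shifted]  # shape: (n_samples, n_fields, 4)
    print(shifted, table_size, vectors.shape)
```

In practice the field sizes would be computed over the training and test data together (or with a reserved "unknown" id per field), so that no test id spills into a neighbouring field's range.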
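Unlike the full DeepGBM, which uses both parts, GBDT2NN needs only the numerical data and CatNN needs only the categorical data, so if you run either of them alone you only have to feed that part into the model. You can then call the corresponding functions in train_models.py, for example: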
train_GBDT2NN(args, num_data, plot_title)
# or
train_DEEPGBM(args, num_data, cate_data, plot_title)
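The choice of inputs described above can be summarized as a small lookup. This sketch assumes the data names follow the `args.data + '_num'` / `args.data + '_cate'` pattern from the loading snippet; `REQUIRED_PARTS` and `files_to_load` are hypothetical helpers, not part of the repo.

```python
from typing import Dict, Tuple

# Which input parts each variant consumes:
# GBDT2NN -> numerical only, CatNN -> categorical only, DeepGBM -> both.
REQUIRED_PARTS: Dict[str, Tuple[str, ...]] = {
    "gbdt2nn": ("num",),
    "catnn": ("cate",),
    "deepgbm": ("num", "cate"),
}


def files_to_load(dataset: str, model: str) -> Tuple[str, ...]:
    """Return the data names a loader such as dh.load_data would need for one model."""
    try:
        parts = REQUIRED_PARTS[model.lower()]
    except KeyError:
        raise ValueError(
            f"unknown model {model!r}; expected one of {sorted(REQUIRED_PARTS)}") from None
    return tuple(f"{dataset}_{part}" for part in parts)


if __name__ == "__main__":
    print(files_to_load("toy", "DeepGBM"))
    print(files_to_load("toy", "GBDT2NN"))
```

Keeping the mapping in one table makes it easy to skip loading the part a model does not use, which matters when one of the two arrays is large.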
Thanks for visiting. If you have any questions, please open a new issue.