### train_model.py ###

#!/usr/bin/env python
# coding=utf-8

import codecs
import simplejson as json
import numpy as np
import pandas as pd
from keras.models import Sequential, load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.preprocessing import sequence
from keras.utils import to_categorical
from keras.layers import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.externals import joblib
import logging
import re
import pickle as pkl

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(filename)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M',
                    filename='log/train_model.log',
                    filemode='a+')

# fastText-style n-gram range: 1 means unigrams (single characters) only
ngram_range = 1
max_features = 6500
maxlen = 120

fw = open('error_line_test.txt', 'wb')
DIRTY_LABEL = re.compile(r'\W+')
# e.g. set([u'business', u'commission agency', u'construction', u'planning',
#           u'design', u'sales', u'except', u'retail', u'food'])
STOP_WORDS = pkl.load(open('./data/stopwords.pkl', 'rb'))


def load_data(fname='data/12315_industry_business_train.csv', nrows=None):
"""
Load training data
"""
data, labels = [], []
char2idx = json.load(open('data/char2idx.json'))
used_keys = set(['name', 'business'])
df = pd.read_csv(fname, encoding='utf-8', nrows=nrows)
for idx, item in df.iterrows():
item = item.to_dict()
line = ''
for key, value in item.iteritems():
if key in used_keys:
line += key+value data.append([char2idx[char] for char in line if char in char2idx])
labels.append(item['label']) le = LabelEncoder()
logging.info('%d nb_class: %s' % (len(np.unique(labels)), str(np.unique(labels))))
onehot_label = to_categorical(le.fit_transform(labels))
joblib.dump(le, 'model/tgind_labelencoder.h5')
x_train, x_test, y_train, y_test = train_test_split(data, onehot_label, test_size=0.1)
return (x_train, y_train), (x_test, y_test) def create_ngram_set(input_list, ngram_value=2):
    """
    Extract the set of unique n-grams (as tuples) from input_list.
    """
    return set(zip(*[input_list[i:] for i in range(ngram_value)]))


def add_ngram(sequences, token_indice, ngram_range=2):
    """
    Augment each input sequence by appending the indices of its n-grams.
    """
    new_sequences = []
    for input_list in sequences:
        new_list = input_list[:]
        for i in range(len(new_list) - ngram_range + 1):
            for ngram_value in range(2, ngram_range + 1):
                ngram = tuple(new_list[i:i + ngram_value])
                if ngram in token_indice:
                    new_list.append(token_indice[ngram])
        new_sequences.append(new_list)
    return new_sequences
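
# A minimal sketch of the n-gram augmentation above (hypothetical indices):
#   create_ngram_set([1, 4, 9, 4], ngram_value=2) -> {(1, 4), (4, 9), (9, 4)}
#   with token_indice = {(1, 4): 6501, (9, 4): 6502},
#   add_ngram([[1, 4, 9, 4]], token_indice, ngram_range=2) returns
#   [[1, 4, 9, 4, 6501, 6502]]: bigram ids are appended after the unigram ids,
#   so the averaging layer later sees unigram and bigram embeddings together.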

(x_train, y_train), (x_test, y_test) = load_data()
nb_class = y_train.shape[1]

logging.info('x_train size: %d' % (len(x_train)))
logging.info('x_test size: %d' % (len(x_test)))
logging.info('x_train sent average len: %.2f' % (np.mean(list(map(len, x_train)))))
print 'x_train sent avg length: %.2f' % (np.mean(list(map(len, x_train))))

if ngram_range > 1:
    print 'add {}-gram features'.format(ngram_range)
    ngram_set = set()
    for input_list in x_train:
        for i in range(2, ngram_range + 1):
            set_of_ngram = create_ngram_set(input_list, ngram_value=i)
            ngram_set.update(set_of_ngram)
    # Map each n-gram to an integer index placed after the unigram vocabulary
    start_index = max_features + 1
    token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}
    indice_token = {token_indice[k]: k for k in token_indice}
    max_features = np.max(list(indice_token.keys())) + 1
    x_train = add_ngram(x_train, token_indice, ngram_range)
    x_test = add_ngram(x_test, token_indice, ngram_range)

print 'pad sequences (samples x time)'
x_train = sequence.pad_sequences(x_train, maxlen=maxlen, padding='post', truncating='post')
x_test = sequence.pad_sequences(x_test, maxlen=maxlen, padding='post', truncating='post')
logging.info('x_train.shape: %s' % (str(x_train.shape)))

print 'build model...'


def cal_accuracy(x_test, y_test):
"""
Accuracy statistics
"""
y_test = np.argmax(y_test, axis=1)
y_pred = model.predict_classes(x_test)
correct_cnt = np.sum(y_pred==y_test)
return float(correct_cnt)/len(y_test) DEBUG = False
if DEBUG:
    model = Sequential()
    model.add(Embedding(max_features, 200, input_length=maxlen))
    model.add(GlobalAveragePooling1D())
    model.add(Dropout(0.3))
    model.add(Dense(nb_class, activation='softmax'))
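    # Shape flow (a sketch, assuming maxlen=120 and the 200-dim embedding above):
    #   input (None, 120) -> Embedding -> (None, 120, 200)
    #   -> GlobalAveragePooling1D -> (None, 200)  # mean over the sequence axis
    #   -> Dropout -> Dense softmax -> (None, nb_class)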
else:
    # Fine-tune the previously saved model instead of training from scratch
    model = load_model('./model/tgind_dalei.h5')

# model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

earlystop = EarlyStopping(monitor='val_loss', patience=8)
checkpoint = ModelCheckpoint(filepath='./model/tgind_dalei.h5', monitor='val_loss',
                             save_best_only=True, save_weights_only=False)

model.fit(x_train, y_train, shuffle=True, batch_size=64, epochs=80,
          validation_split=0.1, callbacks=[checkpoint, earlystop])

loss, acc = model.evaluate(x_test, y_test)
print '\n\nlast model: loss', loss
print 'acc', acc

model = load_model('model/tgind_dalei.h5')
loss, acc = model.evaluate(x_test, y_test)
print '\n\n cur best model: loss', loss
print 'accuracy', acc
logging.info('loss: %.4f ;accuracy: %.4f' % (loss, acc))
logging.info('\nmodel acc: %.4f' % acc)
logging.info('\nmodel config:\n %s' % model.get_config())
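
For reference, here is a minimal inference sketch using the artifacts the script saves (paths as above; the input text is a made-up placeholder):

from keras.models import load_model
from keras.preprocessing import sequence
from sklearn.externals import joblib
import simplejson as json
import numpy as np

model = load_model('model/tgind_dalei.h5')
le = joblib.load('model/tgind_labelencoder.h5')
char2idx = json.load(open('data/char2idx.json'))

text = u'some company name plus business scope'  # placeholder input
x = [char2idx[c] for c in text if c in char2idx]
x = sequence.pad_sequences([x], maxlen=120, padding='post', truncating='post')
proba = model.predict(x, verbose=0)
print le.inverse_transform(np.argmax(proba, axis=-1))[0]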

### test_model.py ###

#!/usr/bin/env python
# coding=utf-8

import matplotlib.pyplot as plt
from api_tgind import TgIndustry
import pandas as pd
import codecs
import json
from collections import OrderedDict


# ########### Calculate the accuracy based on the threshold ###########
def cal_model_acc(model, fname='./data/industry_dalei_test_sample2k.txt', nrows=None):
"""
Load data , And calculate before 5 The accuracy of
"""
res = {}
res['y_pred'] = []
res['y_true'] = []
with codecs.open(fname, encoding='utf-8') as fr:
for idx, line in enumerate(fr):
tokens = line.strip().split()
if len(tokens)>3:
tokens, label = tokens[:-1], tokens[-1].replace('__labe__', '')
tmp = {}
tmp['business'] = ''.join(tokens)
res['y_pred'].append(model.predict(tmp))
res['y_true'].append(label)
if nrows and idx>nrows:
break
json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
return res def cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv', nrows=None):
"""
Direct basis csv Predicted results
"""
res = {}
res['y_pred'] = []
res['y_true'] = []
df = pd.read_csv(fname, encoding='utf-8')
for idx, item in df.iterrows():
try:
res['y_pred'].append(model.predict(item.to_dict()))
except Exception as e:
print e
print idx
print item['name']
continue
res['y_true'].append(item['label']) if nrows and idx>nrows:
break
json.dump(res, codecs.open('log/total_acc_output.json', 'wb', encoding='utf-8'))
return res def get_model_acc_menlei(res, topk=5, threhold=0.8):
"""
The accuracy of the model is calculated according to the threshold
"""
correct_cnt, total_cnt = 0, 0
for idx, y_pred in enumerate(res['y_pred']):
y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x:float(x[1]), reverse=True) # Probability ranking
y_pred = OrderedDict()
for c, s in y_pred_tuple:
y_pred[c] = float(s) if y_pred.values()[0] > threhold: # The maximum category probability is greater than the threshold threhold
if res['y_true'][idx][0] in map(lambda x:x[0], y_pred.keys()[:topk]):
correct_cnt += 1
total_cnt += 1
acc = float(correct_cnt)/total_cnt
recall = float(total_cnt)/len(res['y_true'])
return acc, recall def get_model_acc_dalei(res, topk=5, threhold=0.8):
"""
The accuracy of the model is calculated according to the threshold
"""
correct_cnt, total_cnt = 0, 0
for idx, y_pred in enumerate(res['y_pred']):
y_pred_tuple = sorted(y_pred.iteritems(), key=lambda x:float(x[1]), reverse=True) # Probability ranking
y_pred = OrderedDict()
for c, s in y_pred_tuple:
y_pred[c] = float(s) if y_pred.values()[0] >= threhold: # The maximum category probability is greater than the threshold threhold
if res['y_true'][idx] in y_pred.keys()[:topk]:
correct_cnt += 1
total_cnt += 1 acc = float(correct_cnt)/total_cnt
recall = float(total_cnt)/len(res['y_true'])
return acc, recall def plot_accuracy(title, df, number):
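
# A toy example of the threshold metrics above (made-up numbers):
#   res = {'y_pred': [{u'52': 0.9, u'58': 0.1}, {u'13': 0.4, u'52': 0.6}],
#          'y_true': [u'52', u'13']}
#   get_model_acc_dalei(res, topk=1, threhold=0.8) scores only the first sample
#   (0.9 >= 0.8) and it is correct, so acc = 1/1 = 1.0 and recall = 1/2 = 0.5.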
"""
Accuracy mapping
"""
    for topk in range(1, 5):
        tmpdf = df[df.topk == topk]
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
        plt.subplots_adjust(top=0.85)
        ax1.plot(tmpdf['threhold'], tmpdf['accuracy'], 'ro-', label='accuracy')
        # ax2 = ax1.twinx()
        ax1.plot(tmpdf['threhold'], tmpdf['recall'], 'g^-', label='recall')
        ax1.set_ylim(0.3, 1.0)
        ax1.legend(loc=3)
        ax1.set_xlabel('threhold')
        plt.grid(True)
        plt.title('%s Industry Classify Result\n topk=%d, number=%d\n' % (title, topk, number))
        plt.savefig('log/test_%s_acc_topk%d.png' % (title, topk))
        print topk, 'done!'


def gen_plot_data(model_acc, ctype='2nd'):
"""
Generate graph data
"""
res = {}
res['accuracy'] = []
res['threhold'] = []
res['topk'] = []
res['recall'] = []
for topk in range(1,5):
for threhold in range(0, 10):
threhold = 0.1*threhold
if ctype == '1st':
acc, recall = get_model_acc_menlei(model_acc, topk, threhold)
else:
acc, recall = get_model_acc_dalei(model_acc, topk, threhold)
res['accuracy'].append(acc)
res['recall'].append(recall)
res['threhold'].append(threhold)
res['topk'].append(topk)
print ctype, topk, acc
json.dump(res, open('log/test_model_threshold_%s.log' % ctype, 'wb'))
df = pd.DataFrame(res)
df.to_csv('log/test_model_result_%s.csv' % ctype, index=False)
plot_accuracy(ctype, df, len(model_acc['y_true']))
return df if __name__=='__main__': model = TgIndustry()
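
# The resulting csv has one row per (topk, threhold) pair; columns come out in
# alphabetical order, e.g. (made-up values):
#   accuracy,recall,threhold,topk
#   0.76,1.00,0.0,1
#   0.81,0.93,0.1,1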


if __name__ == '__main__':
    model = TgIndustry()
    # model_acc = cal_model_acc2(model, fname='data/test_12315_industry_business_sample100.csv')
    model_acc = json.load(codecs.open('log/total_acc_output_12315.json', encoding='utf-8'))
    gen_plot_data(model_acc, '1st')
    gen_plot_data(model_acc, '2nd')

### api_tgind.py ###

#!/usr/bin/env python
# coding=utf-8

import numpy as np
import codecs
import simplejson as json
from keras.models import load_model
from keras.preprocessing import sequence
from sklearn.externals import joblib
from collections import OrderedDict
import pickle as pkl
import re, os
import jieba
import time

"""
Industry classification API

__author__: jkmiao
__date__: 2017-07-05
"""


class TgIndustry(object):

    def __init__(self, model_path='model/tgind_dalei_acc76.h5'):
        base_path = os.path.dirname(__file__)
        model_path = os.path.join(base_path, model_path)
        # Load the pre-trained model
        self.model = load_model(model_path)
        # Load the LabelEncoder
        self.le = joblib.load(os.path.join(base_path, './model/tgind_labelencoder.h5'))
        # Load the character-to-index map
        self.char2idx = json.load(open(os.path.join(base_path, 'data/char2idx.json')))
        # Load stop words
        # self.stop_words = set([line.strip() for line in codecs.open('./data/stopwords.txt', encoding='utf-8')])
        self.stop_words = pkl.load(open(os.path.join(base_path, './data/stopwords.pkl'), 'rb'))
        # Load the category code -> name mappings
        self.menlei_label2name = json.load(open(os.path.join(base_path, 'data/menlei_label2name.json')))  # first-level categories
        self.dalei_label2name = json.load(open(os.path.join(base_path, 'data/dalei_label2name.json')))  # second-level categories

    def predict(self, company_info, topk=2, firstIndustry=False, final_name=False):
"""
:type company_info: Company information
:rtype business: str: Corresponding label
"""
line = ''
for key, value in company_info.iteritems():
if key in ['name', 'business']: # Company information , At present, the company name and business scope
line += company_info[key] if not isinstance(line, unicode):
line = line.decode('utf-8') # Remove the stop words from the sentence
line = ''.join([token for token in jieba.cut(line) if token not in self.stop_words])
data = [self.char2idx[char] for char in line if char in self.char2idx]
data = sequence.pad_sequences([data], maxlen=100, padding='post', truncating='post')
        y_pred_proba = self.model.predict(data, verbose=0)
        # Indices of the topk highest-probability classes, best first
        y_pred_idx_list = [c[-topk:][::-1] for c in np.argsort(y_pred_proba, axis=-1)][0]
        res = OrderedDict()
        for y_pred_idx in y_pred_idx_list:
            y_pred_label = self.le.inverse_transform(y_pred_idx)
            if final_name:
                y_pred_label = self.dalei_label2name[y_pred_label]
            if firstIndustry:
                res[y_pred_label[0]] = round(y_pred_proba[0, y_pred_idx], 3)  # keep 3 decimal places
            res[y_pred_label] = round(y_pred_proba[0, y_pred_idx], 3)  # keep 3 decimal places
        return res
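
    # A sketch of the expected output (hypothetical labels and probabilities):
    #   TgIndustry().predict({'business': u'...'}, topk=2)
    #   -> OrderedDict([(u'52', 0.91), (u'58', 0.04)])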


if __name__ == '__main__':
    DIRTY_LABEL = re.compile(r'\W+')
    test = TgIndustry()
    cnt, total_cnt = 0, 0
    start_time = time.time()
    fw2 = codecs.open('./output/industry_dalei_test_sample2k_error.txt', 'wb', encoding='utf-8')
    with codecs.open('./data/industry_dalei_test_sample2k.txt', encoding='utf-8') as fr:
        for idx, line in enumerate(fr):
            tokens = line.strip().split()
            if len(tokens) > 3:
                tokens, label = tokens[:-1], tokens[-1].replace('__label__', '')
                if len(label) not in [2, 3] or DIRTY_LABEL.search(label):
                    print 'error line:'
                    print idx, line, label
                    continue
                tmp = {}
                tmp['business'] = ''.join(tokens)
                y_pred = test.predict(tmp, topk=1)
                if label in y_pred:
                    cnt += 1
                elif y_pred.values()[0] < 0.3:
                    print 'error: ', ''.join(tokens), y_pred, 'y_true:', label
                    fw2.write(''.join(tokens))
                total_cnt += 1
                print label
                print json.dumps(y_pred, ensure_ascii=False)
                print idx, '==' * 20, float(cnt) / total_cnt
            if idx > 200:
                break
    print 'avg cost time:', float(time.time() - start_time) / idx
