spell checker
Hello, we need to use PyQt5 and Python in this program. Let's start.
First, please clone this GitHub repo:
https://github.com/vuptran/deep-spell-checkr

# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'pencere.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
# Batuhan Ökmen (as you know: flavves)
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
import textract
from PyQt5 import QtCore, QtGui, QtWidgets
import easygui
from nltk.tokenize import RegexpTokenizer
from nltk import tokenize
from collections import Counter
import re
import os
import numpy as np
import unidecode
import keras.backend as K
from tensorflow.keras import models
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.layers import Dense, Flatten
from keras.layers import Input
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras.layers import RepeatVector
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt
from pydotplus import graphviz
from tensorflow.keras.layers import Lambda,Activation
from tensorflow.keras import backend as K
from keras.preprocessing.sequence import pad_sequences
from keras.optimizers import RMSprop
from keras.preprocessing.text import Tokenizer
from keras.models import Model, load_model
from keras.layers import Input
from model import seq2seq
from model import truncated_acc, truncated_loss
# Start-of-sequence marker prepended to decoder inputs in transform().
SOS = '\t'
# End-of-sequence marker; transform() pads with it and decode_sequences()
# strips it from the returned tokens.
EOS = '*'
# Pool of characters add_speling_erors() draws from for substitutions and
# insertions.
ALPHABETCHARS = list('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ')
# Regex character class removed from every token by tokenize().
# Written as a raw string so that `\+`, `\[`, `\]` and `\d` reach the regex
# engine without relying on Python's deprecated handling of unknown string
# escapes; the compiled pattern matches exactly the same characters.
IGNOREDCHARS = r'[#$%"\+@<=>!&,-.?:;()*\[\]^_`{|}~/\d\t\n\r\x0b\x0c]'
# Settings read by CorrectSentences().
error_rate = 0.7  # passed to transform(); must satisfy 0.0 <= rate < 1.0
reverse = True  # passed to decode_sequences(): feed input tokens reversed
model_path = './models/seq2seq.h5'  # file loaded by restore_model()
hidden_size = 512  # size of the decoder state inputs built by restore_model()
sample_mode = 'argmax'  # decode_sequences() accepts 'argmax' or 'multinomial'
data_path = './data'  # directory holding the corpus files listed in `books`
books = ['1.txt', '2.txt', '3.txt', '4.txt']
def read_text(data_path, list_of_books):
    """Read and concatenate the given corpus files as ASCII text.

    Args:
        data_path: Directory that contains the files.
        list_of_books: Iterable of file names inside ``data_path``.

    Returns:
        The transliterated (``unidecode``) content of every file, each one
        followed by a single space, joined in the given order. An empty
        iterable yields ``''``.

    Raises:
        OSError: If one of the files cannot be opened.
    """
    chunks = []
    for book in list_of_books:
        file_path = os.path.join(data_path, book)
        # Context manager so the handle is closed even if reading fails.
        with open(file_path) as handle:
            chunks.append(unidecode.unidecode(handle.read()))
    return ''.join(chunk + ' ' for chunk in chunks)
class CharacterTable(object):
    """Mapping between characters and one-hot row indices.

    Attributes are built from the sorted set of the characters passed in:
    ``chars`` (sorted unique characters), ``char2index``, ``index2char`` and
    ``size`` (number of distinct characters).
    """

    def __init__(self, chars):
        self.chars = sorted(set(chars))
        self.char2index = {c: i for i, c in enumerate(self.chars)}
        self.index2char = dict(enumerate(self.chars))
        self.size = len(self.chars)

    def encode(self, C, nb_rows):
        """One-hot encode the string ``C`` into a ``(nb_rows, size)`` matrix.

        Rows beyond ``len(C)`` stay all-zero.

        Returns:
            The float32 matrix, or ``None`` (after printing ``"uzun"``) when
            ``C`` has more than ``nb_rows`` characters or contains a
            character that is not in the table.
        """
        x = np.zeros((nb_rows, len(self.chars)), dtype=np.float32)
        try:
            for i, c in enumerate(C):
                x[i, self.char2index[c]] = 1.0
        except (IndexError, KeyError):
            # IndexError: token longer than nb_rows; KeyError: unknown char.
            # Only these are swallowed so unrelated failures still surface.
            print("uzun")
            return None
        return x

    def decode(self, x, calc_argmax=True):
        """Turn ``x`` back into characters.

        With ``calc_argmax`` the indices are the argmax of ``x`` along its
        last axis; otherwise ``x`` is taken as the index sequence itself.

        Returns:
            ``(indices, chars)`` where ``chars`` is the joined string.
        """
        if calc_argmax:
            indices = x.argmax(axis=-1)
        else:
            indices = x
        chars = ''.join(self.index2char[ind] for ind in indices)
        return indices, chars

    def sample_multinomial(self, preds, temperature=1.0):
        """Draw one character from ``preds`` after temperature rescaling.

        ``preds`` is reshaped to ``len(self.chars)`` values, its logarithm is
        divided by ``temperature``, the result is renormalised and a single
        multinomial draw selects the index.

        Returns:
            ``(index, char)`` of the sampled character.
        """
        preds = np.reshape(preds, len(self.chars)).astype(np.float64)
        preds = np.log(preds) / temperature
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        probs = np.random.multinomial(1, preds, 1)
        index = np.argmax(probs)
        char = self.index2char[index]
        return index, char
def add_speling_erors(token, error_rate):
    """Randomly corrupt ``token`` with at most one spelling error.

    Tokens shorter than three characters are returned untouched. Otherwise a
    uniform draw selects, each with probability ``error_rate / 4``, one of:
    substitute a character, delete a character, insert a character, or swap
    two adjacent characters. Any other draw leaves the token unchanged.

    ``error_rate`` must satisfy ``0.0 <= error_rate < 1.0``.
    """
    assert(0.0 <= error_rate < 1.0)
    if len(token) < 3:
        return token

    draw = np.random.rand()
    band = error_rate / 4.0

    if draw < band:
        # Substitution: replace one character with a random alphabet one.
        pos = np.random.randint(len(token))
        return token[:pos] + np.random.choice(ALPHABETCHARS) \
            + token[pos + 1:]
    if band < draw < band * 2:
        # Deletion: drop one character.
        pos = np.random.randint(len(token))
        return token[:pos] + token[pos + 1:]
    if band * 2 < draw < band * 3:
        # Insertion: add a random alphabet character.
        pos = np.random.randint(len(token))
        return token[:pos] + np.random.choice(ALPHABETCHARS) \
            + token[pos:]
    if band * 3 < draw < band * 4:
        # Transposition: swap two neighbouring characters.
        pos = np.random.randint(len(token) - 1)
        return token[:pos] + token[pos + 1] \
            + token[pos] + token[pos + 2:]
    # No error introduced.
    return token
def transform(tokens, maxlen, error_rate=0.3, shuffle=True):
    """Build padded encoder, decoder and target strings for each token.

    For every token the encoder string is the token after
    ``add_speling_erors``, the decoder string is ``SOS + token`` and the
    target string is the decoder string without its first character. All
    three are right-padded with ``EOS`` up to ``maxlen``.

    When ``shuffle`` is true, ``tokens`` is shuffled in place first.

    Returns:
        ``(encoder_tokens, decoder_tokens, target_tokens)`` lists.
    """
    if shuffle:
        print(' random data olusturuluyor...')
        np.random.shuffle(tokens)

    def pad(text):
        # Right-pad with EOS; a text already longer than maxlen is unchanged.
        return text + EOS * (maxlen - len(text))

    encoder_tokens, decoder_tokens, target_tokens = [], [], []
    for word in tokens:
        noisy = pad(add_speling_erors(word, error_rate=error_rate))
        shifted = pad(SOS + word)
        expected = pad(shifted[1:])
        encoder_tokens.append(noisy)
        decoder_tokens.append(shifted)
        target_tokens.append(expected)
        # All three strings must end up with the same length.
        assert(len(noisy) == len(shifted) == len(expected))
    return encoder_tokens, decoder_tokens, target_tokens
def tokenize(text):
    """Split ``text`` on hyphens, newlines and spaces and clean each piece.

    Every piece has the characters matched by ``IGNOREDCHARS`` removed.
    Pieces that become empty are kept in the returned list.
    """
    pieces = re.split("[-\n ]", text)
    return [re.sub(IGNOREDCHARS, '', piece) for piece in pieces]
def decode_sequences(inputs, targets, input_ctable, target_ctable,
                     maxlen, reverse, encoder_model, decoder_model,
                     nb_examples, sample_mode='argmax', random=True):
    """Decode ``nb_examples`` input tokens character by character.

    The examples are either drawn at random from ``inputs``/``targets``
    (``random=True``) or taken as the first ``nb_examples`` entries. The
    encoder model produces the initial states; the decoder model is then
    called up to ``maxlen`` times, starting from ``SOS``, and each step's
    character is chosen by ``sample_mode`` (``'argmax'`` or
    ``'multinomial'``). Decoding stops early once every example emits
    ``EOS`` in the same step.

    Returns:
        ``(input_tokens, target_tokens, decoded_tokens)`` with every ``EOS``
        character removed.

    Raises:
        Exception: If ``sample_mode`` is not one of the accepted values.
    """
    if random:
        picked = np.random.randint(0, len(inputs), nb_examples)
    else:
        picked = range(nb_examples)

    input_tokens = []
    target_tokens = []
    for position in picked:
        input_tokens.append(inputs[position])
        target_tokens.append(targets[position])

    # One batch holding exactly the selected examples.
    encoded_inputs = next(batch(input_tokens, maxlen, input_ctable,
                                nb_examples, reverse))
    states_value = encoder_model.predict(encoded_inputs)

    # Every example starts decoding from the start-of-sequence character.
    step_shape = (nb_examples, 1, target_ctable.size)
    target_sequences = np.zeros(step_shape)
    target_sequences[:, 0, target_ctable.char2index[SOS]] = 1.0

    decoded_tokens = [''] * nb_examples
    for _ in range(maxlen):
        char_probs, h, c = decoder_model.predict(
            [target_sequences] + states_value)
        # Fresh one-hot buffer for the characters chosen in this step.
        target_sequences = np.zeros(step_shape)
        step_chars = []
        for row in range(nb_examples):
            if sample_mode == 'argmax':
                next_index, next_char = target_ctable.decode(
                    char_probs[row], calc_argmax=True)
            elif sample_mode == 'multinomial':
                next_index, next_char = target_ctable.sample_multinomial(
                    char_probs[row], temperature=0.5)
            else:
                raise Exception(
                    "`sample_mode` accepts `argmax` or `multinomial`.")
            decoded_tokens[row] += next_char
            step_chars.append(next_char)
            target_sequences[row, 0, next_index] = 1.0
        # Stop once every example produced the end-of-sequence character.
        if set(step_chars) == {EOS}:
            break
        states_value = [h, c]

    eos_pattern = '[%s]' % EOS

    def strip_eos(words):
        return [re.sub(eos_pattern, '', word) for word in words]

    return (strip_eos(input_tokens), strip_eos(target_tokens),
            strip_eos(decoded_tokens))
def batch(tokens, maxlen, ctable, batch_size=128, reverse=False):
    """Endlessly yield one-hot encoded batches built from ``tokens``.

    Tokens are consumed cyclically (optionally reversed) and encoded with
    ``ctable.encode(token, maxlen)`` into a float32 array of shape
    ``(batch_size, maxlen, ctable.size)``. The same array object is refilled
    and yielded on every iteration.
    """
    def endless_tokens():
        # Cycle over the tokens forever, reversing each one if requested.
        while True:
            for item in tokens:
                yield item[::-1] if reverse else item

    source = endless_tokens()
    data_batch = np.zeros((batch_size, maxlen, ctable.size),
                          dtype=np.float32)
    while True:
        for row in range(batch_size):
            data_batch[row] = ctable.encode(next(source), maxlen)
        yield data_batch
def restore_model(path_to_full_model, hidden_size):
    """Load the saved seq2seq model and rebuild its inference models.

    The full model is loaded with the ``truncated_acc`` and
    ``truncated_loss`` custom objects. Its layers named ``encoder_lstm_1``,
    ``encoder_lstm_2``, ``decoder_lstm`` and ``decoder_softmax`` are rewired
    into an encoder model (input -> ``[state_h, state_c]``) and a decoder
    model (input + two state inputs -> output + new states).

    Returns:
        ``(encoder_model, decoder_model)``.
    """
    custom = {'truncated_acc': truncated_acc,
              'truncated_loss': truncated_loss}
    full_model = load_model(path_to_full_model, custom_objects=custom)

    # Encoder: reuse both trained LSTM layers and expose the final states.
    enc_input = full_model.input[0]  # encoder_data
    first_lstm = full_model.get_layer('encoder_lstm_1')
    second_lstm = full_model.get_layer('encoder_lstm_2')
    _, enc_h, enc_c = second_lstm(first_lstm(enc_input))
    encoder_model = Model(inputs=enc_input, outputs=[enc_h, enc_c])

    # Decoder: states come in as explicit inputs of size ``hidden_size``.
    dec_input = full_model.input[1]  # decoder_data
    state_in_h = Input(shape=(hidden_size,))
    state_in_c = Input(shape=(hidden_size,))
    state_inputs = [state_in_h, state_in_c]
    dec_lstm = full_model.get_layer('decoder_lstm')
    dec_seq, dec_h, dec_c = dec_lstm(dec_input, initial_state=state_inputs)
    softmax = full_model.get_layer('decoder_softmax')
    dec_probs = softmax(dec_seq)
    decoder_model = Model(inputs=[dec_input] + state_inputs,
                          outputs=[dec_probs] + [dec_h, dec_c])
    return encoder_model, decoder_model
def check_if_equal(list_1, list_2):
    """Return True when both lists hold the same items, ignoring order.

    Lists of different length are never equal; otherwise their sorted
    contents are compared.
    """
    same_length = len(list_1) == len(list_2)
    return same_length and sorted(list_1) == sorted(list_2)
# NOTE(review): a `global` statement at module level has no effect. The
# `uwu_cumle` assigned inside CorrectSentences() is a local variable of that
# function, so no module-level `uwu_cumle` is ever created.
global uwu_cumle
def CorrectSentences(sentence):
    """Run the restored seq2seq model over every word of ``sentence``.

    The corpus named by ``data_path``/``books`` is read to rebuild the
    vocabulary, the maximum token length and the character tables; the model
    at ``model_path`` is restored and used to decode the sentence tokens.
    Intermediate results are printed for debugging.

    Args:
        sentence: Text to check; it is split with ``tokenize``.

    Returns:
        ``(target_tokens, decoded_tokens)``: the cleaned input words and the
        model's output for each of them, in sentence order. Both lists are
        empty when the sentence contains no tokens.

    Raises:
        ValueError: If the corpus yields an empty vocabulary.

    NOTE(review): the sentence tokens go through ``transform`` with the
    module-level ``error_rate``, so random spelling errors are injected into
    the user's words before decoding - confirm this is intended.
    NOTE(review): the corpus is re-read and the model re-loaded on every
    call; consider caching them if this is called repeatedly.
    """
    tokens = list(filter(None, tokenize(sentence)))
    if not tokens:
        # Nothing to correct: skip the corpus/model loading and avoid
        # running the models on a zero-sized batch.
        return [], []
    nb_tokens = len(tokens)

    text = read_text(data_path, books)
    vocab = list(filter(None, set(tokenize(text))))
    maxlen = max([len(token) for token in vocab]) + 2
    train_encoder, train_decoder, _ = transform(
        vocab, maxlen, error_rate=error_rate, shuffle=False)
    misspelled_tokens, _, target_tokens = transform(
        tokens, maxlen, error_rate=error_rate, shuffle=False)

    # Character tables are rebuilt from the padded training strings.
    input_chars = set(' '.join(train_encoder))
    target_chars = set(' '.join(train_decoder))
    input_ctable = CharacterTable(input_chars)
    target_ctable = CharacterTable(target_chars)

    encoder_model, decoder_model = restore_model(model_path, hidden_size)
    input_tokens, target_tokens, decoded_tokens = decode_sequences(
        misspelled_tokens, target_tokens, input_ctable, target_ctable,
        maxlen, reverse, encoder_model, decoder_model, nb_tokens,
        sample_mode=sample_mode, random=False)

    print(misspelled_tokens)
    print(input_tokens)
    print(target_tokens)
    print(decoded_tokens)
    print(error_rate)
    print('Input cümlesi: ', ' '.join([token for token in input_tokens]))
    print('Decoded cümle:', ' '.join([token for token in decoded_tokens]))
    print('Target cümlesi: ', ' '.join([token for token in target_tokens]))
    print('Maksimum kelime uzunluğu:',maxlen)
    uwu_cumle = [token for token in decoded_tokens]
    print(uwu_cumle)

    # Report decoded words that turn an out-of-vocabulary word into a known
    # one. A set makes each membership test O(1) instead of scanning a list.
    known = set(vocab)
    for word, target, decoded in zip(tokens, target_tokens, decoded_tokens):
        if word not in known and target in known:
            print(decoded)
            word = target
        if word not in known and decoded in known:
            print(decoded)
    return target_tokens, decoded_tokens
class Ui_MainWindow(object):
    """Main window of the spell-checker GUI.

    Three text panes (loaded text, text with suspicious words highlighted,
    corrected text), a multi-selection list of the suspicious words and
    buttons to load, correct, select all, save and cancel.
    """

    def setupUi(self, MainWindow):
        """Create all widgets, wire the button signals and reset the state."""
        MainWindow.setObjectName("MainWindow")
        MainWindow.setWindowModality(QtCore.Qt.ApplicationModal)
        MainWindow.setEnabled(True)
        MainWindow.resize(800, 600)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(MainWindow.sizePolicy().hasHeightForWidth())
        MainWindow.setSizePolicy(sizePolicy)
        MainWindow.setMinimumSize(QtCore.QSize(800, 600))
        MainWindow.setMaximumSize(QtCore.QSize(800, 600))
        MainWindow.setMouseTracking(False)
        MainWindow.setTabletTracking(False)
        MainWindow.setContextMenuPolicy(QtCore.Qt.NoContextMenu)
        MainWindow.setAcceptDrops(False)
        icon = QtGui.QIcon.fromTheme("0")
        MainWindow.setWindowIcon(icon)
        MainWindow.setWindowOpacity(1.0)
        MainWindow.setAutoFillBackground(False)
        MainWindow.setAnimated(True)
        MainWindow.setDocumentMode(False)
        MainWindow.setTabShape(QtWidgets.QTabWidget.Rounded)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.widget = QtWidgets.QWidget(self.centralwidget)
        self.widget.setGeometry(QtCore.QRect(0, 0, 796, 591))
        self.widget.setObjectName("widget")
        self.gridLayout = QtWidgets.QGridLayout(self.widget)
        self.gridLayout.setContentsMargins(8, 8, 8, 8)
        self.gridLayout.setHorizontalSpacing(7)
        self.gridLayout.setObjectName("gridLayout")
        # Column 0: loaded text and the load button.
        self.verticalLayout_4 = QtWidgets.QVBoxLayout()
        self.verticalLayout_4.setObjectName("verticalLayout_4")
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")
        self.label = QtWidgets.QLabel(self.widget)
        font = QtGui.QFont()
        font.setPointSize(10)
        self.label.setFont(font)
        self.label.setObjectName("label")
        self.verticalLayout.addWidget(self.label)
        self.textEdit = QtWidgets.QTextEdit(self.widget)
        font = QtGui.QFont()
        font.setPointSize(14)
        self.textEdit.setFont(font)
        self.textEdit.setObjectName("textEdit")
        self.verticalLayout.addWidget(self.textEdit)
        self.verticalLayout_4.addLayout(self.verticalLayout)
        self.pushButton = QtWidgets.QPushButton(self.widget)
        font = QtGui.QFont()
        font.setPointSize(12)
        self.pushButton.setFont(font)
        self.pushButton.setObjectName("pushButton")
        self.verticalLayout_4.addWidget(self.pushButton)
        self.gridLayout.addLayout(self.verticalLayout_4, 0, 0, 1, 1)
        # Column 2: corrected text, word list and the select-all button.
        self.verticalLayout_6 = QtWidgets.QVBoxLayout()
        self.verticalLayout_6.setObjectName("verticalLayout_6")
        self.verticalLayout_3 = QtWidgets.QVBoxLayout()
        self.verticalLayout_3.setObjectName("verticalLayout_3")
        self.label_3 = QtWidgets.QLabel(self.widget)
        font = QtGui.QFont()
        font.setPointSize(10)
        self.label_3.setFont(font)
        self.label_3.setObjectName("label_3")
        self.verticalLayout_3.addWidget(self.label_3)
        self.textEdit_3 = QtWidgets.QTextEdit(self.widget)
        font = QtGui.QFont()
        font.setPointSize(14)
        self.textEdit_3.setFont(font)
        self.textEdit_3.setObjectName("textEdit_3")
        self.verticalLayout_3.addWidget(self.textEdit_3)
        self.verticalLayout_6.addLayout(self.verticalLayout_3)
        self.listWidget = QtWidgets.QListWidget(self.widget)
        self.listWidget.setAutoScrollMargin(16)
        self.listWidget.setDragEnabled(False)
        self.listWidget.setSelectionMode(QtWidgets.QAbstractItemView.MultiSelection)
        self.listWidget.setTextElideMode(QtCore.Qt.ElideRight)
        self.listWidget.setMovement(QtWidgets.QListView.Free)
        self.listWidget.setFlow(QtWidgets.QListView.TopToBottom)
        self.listWidget.setProperty("isWrapping", False)
        self.listWidget.setResizeMode(QtWidgets.QListView.Fixed)
        self.listWidget.setLayoutMode(QtWidgets.QListView.SinglePass)
        self.listWidget.setGridSize(QtCore.QSize(24, 28))
        self.listWidget.setViewMode(QtWidgets.QListView.ListMode)
        self.listWidget.setUniformItemSizes(False)
        self.listWidget.setSelectionRectVisible(True)
        self.listWidget.setObjectName("listWidget")
        self.verticalLayout_6.addWidget(self.listWidget)
        self.pushButton_3 = QtWidgets.QPushButton(self.widget)
        font = QtGui.QFont()
        font.setPointSize(12)
        self.pushButton_3.setFont(font)
        self.pushButton_3.setObjectName("pushButton_3")
        self.verticalLayout_6.addWidget(self.pushButton_3)
        self.gridLayout.addLayout(self.verticalLayout_6, 0, 2, 1, 1)
        # Column 1: highlighted text and the correct button.
        self.verticalLayout_5 = QtWidgets.QVBoxLayout()
        self.verticalLayout_5.setObjectName("verticalLayout_5")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.label_2 = QtWidgets.QLabel(self.widget)
        font = QtGui.QFont()
        font.setPointSize(10)
        self.label_2.setFont(font)
        self.label_2.setObjectName("label_2")
        self.verticalLayout_2.addWidget(self.label_2)
        self.textEdit_2 = QtWidgets.QTextEdit(self.widget)
        font = QtGui.QFont()
        font.setPointSize(14)
        self.textEdit_2.setFont(font)
        self.textEdit_2.setObjectName("textEdit_2")
        self.verticalLayout_2.addWidget(self.textEdit_2)
        self.verticalLayout_5.addLayout(self.verticalLayout_2)
        self.pushButton_2 = QtWidgets.QPushButton(self.widget)
        font = QtGui.QFont()
        font.setPointSize(12)
        self.pushButton_2.setFont(font)
        self.pushButton_2.setObjectName("pushButton_2")
        self.verticalLayout_5.addWidget(self.pushButton_2)
        self.gridLayout.addLayout(self.verticalLayout_5, 0, 1, 1, 1)
        # Bottom row: cancel and save buttons.
        self.pushButton_5 = QtWidgets.QPushButton(self.widget)
        font = QtGui.QFont()
        font.setPointSize(12)
        self.pushButton_5.setFont(font)
        self.pushButton_5.setObjectName("pushButton_5")
        self.gridLayout.addWidget(self.pushButton_5, 1, 2, 1, 1)
        self.pushButton_4 = QtWidgets.QPushButton(self.widget)
        font = QtGui.QFont()
        font.setPointSize(12)
        self.pushButton_4.setFont(font)
        self.pushButton_4.setObjectName("pushButton_4")
        self.gridLayout.addWidget(self.pushButton_4, 1, 1, 1, 1)
        MainWindow.setCentralWidget(self.centralwidget)
        self.retranslateUi(MainWindow)
        self.listWidget.setCurrentRow(-1)
        self.pushButton.clicked.connect(self.metinYukle)
        self.pushButton_2.clicked.connect(self.metniDuzelt)
        self.pushButton_4.clicked.connect(self.kaydet)
        self.pushButton_5.clicked.connect(self.iptal)
        self.pushButton_3.clicked.connect(self.tumunuSec)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)
        self.pushButton_4.setEnabled(False)
        self.pushButton_5.setEnabled(False)
        self.listWidget.itemClicked.connect(self.degisim)
        # Application state.
        self.metin = ""  # raw text of the loaded document
        self.dosyaURL = "adc.adc"  # path of the loaded document (placeholder)
        self.tahminler = []  # model output for every word
        self.duzeltilecek = []  # words currently selected for replacement
        self.toplam = []  # every word whose model output differs from it
        self.metnimiz = []  # cleaned input words

    @staticmethod
    def _render_corrected_html(words, predictions, accepted):
        """Build the HTML shown in the corrected-text pane.

        Words equal to their prediction are emitted as plain text. A
        differing word is replaced by its prediction (green) when it is in
        ``accepted`` and shown unchanged (red) otherwise. Words and
        predictions are paired by position, so repeated words each get
        their own prediction.
        """
        accepted = set(accepted)
        parts = []
        for word, guess in zip(words, predictions):
            if word == guess:
                parts.append(word + " ")
            elif word in accepted:
                parts.append("<b style='color:#5dff00;font-size:16;'>" + guess + " " + "</b> ")
            else:
                parts.append("<b style='color:red;font-size:16;'>" + word + "</b> ")
        return "".join(parts)

    @staticmethod
    def _corrected_path(source_path):
        """Return the save path: ``source_path`` without extension plus ``_DUZELTILMIS.txt``."""
        return os.path.splitext(source_path)[0] + "_DUZELTILMIS.txt"

    @staticmethod
    def _read_document(path):
        """Return the text of ``path``.

        ``.txt`` files are read directly as UTF-8; everything else is
        extracted with ``textract``, whose byte output is decoded as UTF-8.
        """
        if os.path.splitext(path)[1].lower() == ".txt":
            with open(path, "rb") as f:
                return f.read().decode("UTF-8")
        return textract.process(path).decode("utf-8")

    def degisim(self, MainWindow):
        """Refresh the corrected-text pane from the current list selection."""
        self.duzeltilecek = [str(item.text())
                             for item in self.listWidget.selectedItems()]
        self.textEdit_3.setHtml(self._render_corrected_html(
            self.metnimiz, self.tahminler, self.duzeltilecek))

    def metinYukle(self, MainWindow):
        """Ask for a file and load its text into the first pane.

        A cancelled dialog or an unreadable file leaves the current document
        untouched; read errors are reported on stdout instead of being
        silently discarded.
        """
        chosen = easygui.fileopenbox()
        if chosen:
            try:
                loaded = self._read_document(chosen)
            except Exception as e:
                # UI boundary: report the problem and keep the window alive.
                print("could not load file: ", str(e))
            else:
                self.dosyaURL = chosen
                self.metin = loaded
                self.textEdit.setText(self.metin)
        self.pushButton_5.setEnabled(True)

    def metniDuzelt(self, MainWindow):
        """Run the spell checker on the first pane and show the result."""
        ret = CorrectSentences(self.textEdit.toPlainText())
        self.listWidget.clear()
        self.metnimiz = ret[0]
        self.tahminler = ret[1]
        self.duzeltilecek = []
        self.toplam = []
        parts = []
        for word, guess in zip(self.metnimiz, self.tahminler):
            if word == guess:
                parts.append(word + " ")
            else:
                parts.append("<b style='color:red;font-size:16;'>" + word + "</b> ")
                self.duzeltilecek.append(word)
                self.toplam.append(word)
                self.listWidget.addItem(word)
                # Make the row just added current; `toplam` grows in
                # lockstep with the list widget.
                self.listWidget.setCurrentRow(len(self.toplam) - 1)
        self.textEdit_2.setHtml("".join(parts))
        self.degisim(MainWindow)
        self.pushButton_4.setEnabled(True)
        print("deneme")

    def kaydet(self, MainWindow):
        """Write the corrected text next to the loaded file."""
        metin = self.textEdit_3.toPlainText()
        with open(self._corrected_path(self.dosyaURL), "w",
                  encoding="utf-8") as dosya:
            dosya.write(metin)

    def iptal(self, MainWindow):
        """Clear all panes and reset the state; disables the save button."""
        self.textEdit.setText("")
        self.textEdit_2.setText("")
        self.textEdit_3.setText("")
        self.listWidget.clear()
        self.metin = ""
        self.dosyaURL = "adc.adc"
        self.tahminler = []
        self.duzeltilecek = []
        self.toplam = []
        self.metnimiz = []
        self.pushButton_4.setEnabled(False)

    def tumunuSec(self, MainWindow):
        """Rebuild the word list from ``toplam`` and refresh the result pane."""
        self.listWidget.clear()
        for row, word in enumerate(self.toplam):
            self.listWidget.addItem(word)
            self.listWidget.setCurrentRow(row)
        self.degisim(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and all user-visible widget texts."""
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "Metin Düzenleme"))
        self.label.setText(_translate("MainWindow", "METIN"))
        self.textEdit.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
"p, li { white-space: pre-wrap; }\n"
"</style></head><body style=\" font-family:\'MS Shell Dlg 2\'; font-size:14pt; font-style:normal;\">\n"
"</body></html>"))
        self.textEdit_2.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
"p, li { white-space: pre-wrap; }\n"
"</style></head><body style=\" font-family:\'MS Shell Dlg 2\'; font-size:14pt; font-style:normal;\">\n"
"</body></html>"))
        self.textEdit_3.setHtml(_translate("MainWindow", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
"<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
"p, li { white-space: pre-wrap; }\n"
"</style></head><body style=\" font-family:\'MS Shell Dlg 2\'; font-size:14pt; font-style:normal;\">\n"
"</body></html>"))
        self.pushButton.setText(_translate("MainWindow", "METİN YÜKLE"))
        self.label_3.setText(_translate("MainWindow", "DÜZELTİLMİŞ METİN"))
        self.listWidget.setSortingEnabled(False)
        self.pushButton_3.setText(_translate("MainWindow", "TÜMÜNÜ SEÇ"))
        self.label_2.setText(_translate("MainWindow", "HATALI YAZILMIŞ METİN"))
        self.pushButton_2.setText(_translate("MainWindow", "METNİ DÜZENLE"))
        self.pushButton_5.setText(_translate("MainWindow", "İPTAL"))
        self.pushButton_4.setText(_translate("MainWindow", "KAYDET"))
if __name__ == "__main__":
    # Script entry point: create the Qt application and main window, build
    # the UI on it, show it and hand control to the Qt event loop.
    import sys
    app = QtWidgets.QApplication(sys.argv)
    MainWindow = QtWidgets.QMainWindow()
    ui = Ui_MainWindow()
    ui.setupUi(MainWindow)
    MainWindow.show()
    # Exit the process with the event loop's return code.
    sys.exit(app.exec_())
The code above is the main GUI code.
Also add model.py:
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Dropout
from keras import optimizers, metrics, backend as K
# Number of leading positions along axis 1 that truncated_acc() and
# truncated_loss() keep when slicing y_true / y_pred.
VAL_MAXLEN = 16
def truncated_acc(y_true, y_pred):
    """Mean categorical accuracy over the first ``VAL_MAXLEN`` positions.

    Both tensors are sliced to ``[:, :VAL_MAXLEN, :]`` before the accuracy
    is computed and averaged over the last axis.
    """
    head = slice(None, VAL_MAXLEN)
    step_accuracy = metrics.categorical_accuracy(
        y_true[:, head, :], y_pred[:, head, :])
    return K.mean(step_accuracy, axis=-1)
def truncated_loss(y_true, y_pred):
    """Mean categorical cross-entropy over the first ``VAL_MAXLEN`` positions.

    Both tensors are sliced to ``[:, :VAL_MAXLEN, :]``; ``y_pred`` is treated
    as probabilities (``from_logits=False``) and the loss is averaged over
    the last axis.
    """
    head = slice(None, VAL_MAXLEN)
    step_loss = K.categorical_crossentropy(
        target=y_true[:, head, :], output=y_pred[:, head, :],
        from_logits=False)
    return K.mean(step_loss, axis=-1)
def seq2seq(hidden_size, nb_input_chars, nb_target_chars):
    """Build the character-level encoder/decoder training and inference models.

    The encoder stacks two LSTM layers (``encoder_lstm_1``,
    ``encoder_lstm_2``) and its final states initialise a single decoder
    LSTM (``decoder_lstm``) followed by a softmax ``Dense`` layer
    (``decoder_softmax``). The training model is compiled with Adam and
    categorical cross-entropy, tracking ``accuracy``, ``truncated_acc`` and
    ``truncated_loss``.

    Returns:
        ``(model, encoder_model, decoder_model)``: the compiled training
        model plus the two inference models that share its layers.
    """
    # --- Encoder -----------------------------------------------------------
    encoder_inputs = Input(shape=(None, nb_input_chars),
                           name='encoder_data')
    first_encoder = LSTM(hidden_size, recurrent_dropout=0.2,
                         return_sequences=True, return_state=False,
                         name='encoder_lstm_1')
    encoded_sequence = first_encoder(encoder_inputs)
    second_encoder = LSTM(hidden_size, recurrent_dropout=0.2,
                          return_sequences=False, return_state=True,
                          name='encoder_lstm_2')
    _, enc_h, enc_c = second_encoder(encoded_sequence)
    encoder_states = [enc_h, enc_c]

    # --- Decoder (training graph) -----------------------------------------
    decoder_inputs = Input(shape=(None, nb_target_chars),
                           name='decoder_data')
    decoder_lstm = LSTM(hidden_size, dropout=0.2, return_sequences=True,
                        return_state=True, name='decoder_lstm')
    train_sequence, _, _ = decoder_lstm(decoder_inputs,
                                        initial_state=encoder_states)
    decoder_softmax = Dense(nb_target_chars, activation='softmax',
                            name='decoder_softmax')
    train_probs = decoder_softmax(train_sequence)

    model = Model(inputs=[encoder_inputs, decoder_inputs],
                  outputs=train_probs)
    adam = optimizers.Adam(lr=0.001, decay=0.0)
    model.compile(optimizer=adam, loss='categorical_crossentropy',
                  metrics=['accuracy', truncated_acc, truncated_loss])

    # --- Inference models sharing the trained layers -----------------------
    encoder_model = Model(inputs=encoder_inputs, outputs=encoder_states)

    state_in_h = Input(shape=(hidden_size,))
    state_in_c = Input(shape=(hidden_size,))
    state_inputs = [state_in_h, state_in_c]
    step_sequence, step_h, step_c = decoder_lstm(
        decoder_inputs, initial_state=state_inputs)
    step_probs = decoder_softmax(step_sequence)
    decoder_model = Model(inputs=[decoder_inputs] + state_inputs,
                          outputs=[step_probs] + [step_h, step_c])

    return model, encoder_model, decoder_model