1、FFT部分
#include <M5StickCPlus.h>
#pragma GCC optimize ("O3")
#include <driver/i2s.h>
#include <math.h>
#include <string.h>
#include <stdio.h>
#include "dywapitchtrack.h"
#include "fix_fft.h"
// GPIO passed to the I2S driver as the word-select pin (see i2sInit()).
#define PIN_CLK 0
// GPIO passed to the I2S driver as the data-in pin (see i2sInit()).
#define PIN_DATA 34
#define SAMPLES 1024 // Must be a power of 2
// Bytes requested from i2s_read() per loop() pass: SAMPLES 16-bit values.
#define READ_LEN (2 * SAMPLES)
// Display width of the M5StickCPlus
#define TFT_WIDTH 240
// Display height of the M5StickCPlus
#define TFT_HEIGHT 135
// Number of spectrum bars and the horizontal space each one gets.
#define BANDS 8
#define BANDS_WIDTH ( TFT_WIDTH / BANDS )
#define BANDS_PADDING 8
#define BAR_WIDTH ( BANDS_WIDTH - BANDS_PADDING )
// Squared FFT magnitudes at or below this value are ignored by showSpectrumBars().
#define NOISE_FLOOR 1
// Offset, gain and shifts used by showSpectrumBars() to turn FIX_LOG2 output into a bar height.
#define AMPLIFIER (1 << 2)
#define MAGNIFY 3
#define RSHIFT 13
#define RSHIFT2 1
// Minimum peak-to-peak span used for vertical scaling in showOscilloscope().
#define OSC_NOISEFLOOR 100
#define OSC_SAMPLES DYWAPT_SAMPLESIZE
// Number of SAMPLES-sized reads loop() collects before handing a buffer to the pitch tracker.
#define OSC_SKIPCOUNT (OSC_SAMPLES/SAMPLES)
// When > 0, showOscilloscope() skips this many frames between redraws.
#define OSC_EXTRASKIP 0
// Length of each vTemp buffer. Spectrum mode writes indices up to 2*SAMPLES-1,
// so this must be at least 2*SAMPLES.
//#define MAXBUFSIZE (2*SAMPLES)
#define MAXBUFSIZE OSC_SAMPLES
// Per-band display state for the spectrum-bar view.
struct eqBand {
// Label printed above the band's column by initMode().
const char *freqname;
// Peak-marker height in pixels: raised by displayBand(), lowered by 2 on every showSpectrumBars() pass.
int peak;
// Height at which the peak marker was last drawn; used to erase it before redrawing.
int lastpeak;
// Bar height drawn by the previous displayBand() call for this band.
uint16_t lastval;
};
// Display modes, cycled by button A in loop(). ModeCount is the number of modes
// and is used only for the wrap-around test.
enum : uint8_t {
ModeSpectrumBars,
ModeOscilloscope,
ModeTuner,
ModeCount,
};
// Currently selected mode; written by loop(), read by initMode() and looptask().
static uint8_t runmode = 0;
// Set by loop() when a buffer is ready to render; cleared by looptask() after rendering.
static volatile bool semaphore = false;
// Set by loop() on a mode change; looptask() then calls initMode() and clears it.
static volatile bool needinit = true;
// One entry per spectrum bar, in left-to-right drawing order.
static eqBand audiospectrum[BANDS] = {
// Field order: freqname, peak, lastpeak, lastval.
// Each initializer sets the frequency label and the initial peak; the
// remaining members (lastpeak, lastval) are zero-initialized.
{ ".1k", 0 },
{ ".2k", 0 },
{ ".5k", 0 },
{ " 1k", 0 },
{ " 2k", 0 },
{ " 4k", 0 },
{ " 8k", 0 },
{ "16k", 0 }
};
// Double-buffered, DC-removed samples: loop() fills vTemp[curbuf] while the
// render task reads vTemp[curbuf ^ 1].
static int vTemp[2][MAXBUFSIZE];
// Index of the buffer loop() is currently filling.
static uint8_t curbuf = 0;
static uint16_t colormap[TFT_HEIGHT];//color palette for the band meter(pre-fill in setup)
// State for the dywapitch pitch tracker used by the oscilloscope and tuner views.
static dywapitchtracker pitchTracker;
// Off-screen sprite used by the oscilloscope and tuner views.
static TFT_eSprite sprite(&M5.Lcd);
// Number of SAMPLES-sized chunks already stored in the current buffer.
static int bufposcount = 0;
// Note names indexed by (note % 12).
static const char *notestr[12] = {
"C ", "C#", "D ", "D#", "E ", "F ", "F#", "G ", "G#", "A ", "A#", "B "
};
// Double-buffered raw samples exactly as read from I2S; same indexing as vTemp.
static uint16_t oscbuf[2][OSC_SAMPLES];
#if OSC_EXTRASKIP > 0
// Frames skipped since the last oscilloscope redraw.
static int skipcount = 0;
#endif
// Install the I2S driver on port 0 as a PDM receiver and route it to the
// microphone pins (PIN_CLK as word-select/clock, PIN_DATA as data in).
// The sample rate is DYWAPT_SAMPLERATE so the pitch tracker sees the rate it
// was built for.
void i2sInit() {
  i2s_config_t i2s_config = {
    .mode = (i2s_mode_t)(I2S_MODE_MASTER | I2S_MODE_RX | I2S_MODE_PDM),
    .sample_rate = DYWAPT_SAMPLERATE,
    .bits_per_sample = I2S_BITS_PER_SAMPLE_16BIT, //is fixed at 12bit,stereo,MSB
    .channel_format = I2S_CHANNEL_FMT_ALL_RIGHT,
    .communication_format = I2S_COMM_FORMAT_I2S,
    .intr_alloc_flags = ESP_INTR_FLAG_LEVEL1,
    .dma_buf_count = 2,
    .dma_buf_len = 128,
  };

  i2s_pin_config_t pin_config;
  // ESP-IDF 4.4 added a master-clock pin to i2s_pin_config_t. Leaving it
  // uninitialised hands the driver an arbitrary GPIO number, so mark it unused.
#ifdef ESP_IDF_VERSION_VAL
#if ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(4, 4, 0)
  pin_config.mck_io_num = I2S_PIN_NO_CHANGE;
#endif
#endif
  pin_config.bck_io_num = I2S_PIN_NO_CHANGE;
  pin_config.ws_io_num = PIN_CLK;
  pin_config.data_out_num = I2S_PIN_NO_CHANGE;
  pin_config.data_in_num = PIN_DATA;

  i2s_driver_install(I2S_NUM_0, &i2s_config, 0, NULL);
  i2s_set_pin(I2S_NUM_0, &pin_config);
  i2s_set_clk(I2S_NUM_0, DYWAPT_SAMPLERATE, I2S_BITS_PER_SAMPLE_16BIT, I2S_CHANNEL_MONO);
}
void setup() {
M5.begin();
setCpuFrequencyMhz(80);
M5.Lcd.setRotation(1);
M5.Lcd.fillScreen(BLACK);
sprite.createSprite(TFT_WIDTH, TFT_HEIGHT);
sprite.setTextSize(1);
i2sInit();
for(uint8_t i=0;i<TFT_HEIGHT;i++) {
//colormap[i] = M5.Lcd.color565(TFT_HEIGHT-i*.5,i*1.1,0); //RGB
//colormap[i] = M5.Lcd.color565(TFT_HEIGHT-i*4.4,i*2.5,0);//RGB:rev macsbug
float r = TFT_HEIGHT - i;
float g = i;
float mag = (r > g)? 255./r : 255./g;
r *= mag;
g *= mag;
colormap[i] = M5.Lcd.color565((uint8_t)r,(uint8_t)g,0); // Modified by KKQ-KKQ
}
int core = 1 - xPortGetCoreID();
xTaskCreatePinnedToCore(looptask,"calctask",32768,NULL,1,NULL,core);
}
void initMode() {
M5.Lcd.fillRect(0, 0,
TFT_WIDTH, TFT_HEIGHT, BLACK);
switch (runmode) {
//频谱模式
case ModeSpectrumBars:
M5.Lcd.setTextSize(1);
M5.Lcd.setTextColor(LIGHTGREY);
for (byte band = 0; band < BANDS; band++) {
M5.Lcd.setCursor(BANDS_WIDTH*band + 2, 0);
//哦哦哦,我明白了,这其实就只是在打印audiospectrum那个结构体里的频率名称而已
M5.Lcd.print(audiospectrum[band].freqname);
}
break;
case ModeOscilloscope:
{
sprite.setTextColor(GREEN);
dywapitch_inittracking(&pitchTracker);
}
break;
case ModeTuner:
sprite.setTextSize(1);
break;
}
}
//好嘞,这里是实际的频谱的渲染程序部分
void showSpectrumBars(){
int *vTemp_ = vTemp[curbuf^1];
//对mic的数据应用了一个什么窗口,然后计算了一下fft,这里才开始fft
Fixed15FFT::apply_window(vTemp_);
Fixed15FFT::calc_fft(vTemp_, vTemp_ + SAMPLES);
int values[BANDS] = {};
for (int i = 2; i < (SAMPLES/2); i++){
// Don't use sample 0 and only first SAMPLES/2 are usable.
// Each array element represents a frequency and its value the amplitude.
int ampsq = vTemp_[i] * vTemp_[i] + vTemp_[i + SAMPLES] * vTemp_[i + SAMPLES];
if (ampsq > NOISE_FLOOR) {
byte bandNum = getBand(i);
if(bandNum != 8 && ampsq > values[bandNum]) {
values[bandNum] = ampsq;
}
}
}
for (byte band = 0; band < BANDS; band++) {
int log2 = FIX_LOG2<15-RSHIFT>(values[band]);
if (log2 > -AMPLIFIER) {
displayBand(band, ((log2 + AMPLIFIER) * MAGNIFY) >> RSHIFT2);
} else {
displayBand(band, 0);
}
}
//long vnow = millis();
for (byte band = 0; band < BANDS; band++) {
if (audiospectrum[band].peak > 0) {
audiospectrum[band].peak -= 2;
}
if(audiospectrum[band].peak <= 0) {
audiospectrum[band].peak = 0;
}
// only draw if peak changed
if(audiospectrum[band].lastpeak != audiospectrum[band].peak) {
// delete last peak
uint16_t hpos = BANDS_WIDTH*band + (BANDS_PADDING/2);
M5.Lcd.drawFastHLine(hpos,TFT_HEIGHT-audiospectrum[band].lastpeak,BAR_WIDTH,BLACK);
audiospectrum[band].lastpeak = audiospectrum[band].peak;
uint16_t ypos = TFT_HEIGHT - audiospectrum[band].peak;
M5.Lcd.drawFastHLine(hpos, ypos,
BAR_WIDTH, colormap[ypos]);
}
}
// struct eqBand {
// const char *freqname;
// int peak;
// int lastpeak;
// uint16_t lastval;
// };
//试着打印一下
// Serial.print(audiospectrum[0].peak);
// Serial.print(",");
// Serial.print(audiospectrum[1].peak);
// Serial.print(",");
// Serial.print(audiospectrum[2].peak);
// Serial.print(",");
// Serial.print(audiospectrum[3].peak);
// Serial.print(",");
// Serial.print(audiospectrum[4].peak);
// Serial.print(",");
// Serial.print(audiospectrum[5].peak);
// Serial.print(",");
// Serial.print(audiospectrum[6].peak);
// Serial.print(",");
// Serial.print(audiospectrum[7].peak);
// Serial.println();
int dy10=audiospectrum[1].lastpeak-audiospectrum[0].lastpeak;
int dy21=audiospectrum[2].lastpeak-audiospectrum[1].lastpeak;
int dy32=audiospectrum[3].lastpeak-audiospectrum[2].lastpeak;
int dy43=audiospectrum[4].lastpeak-audiospectrum[3].lastpeak;
int dy54=audiospectrum[5].lastpeak-audiospectrum[4].lastpeak;
int dy65=audiospectrum[6].lastpeak-audiospectrum[5].lastpeak;
int dy76=audiospectrum[7].lastpeak-audiospectrum[6].lastpeak;
float dddx = 0.455;
//计算出来了7段斜率
float slope01 = dy10 / dddx;
slope01 = fabs(slope01) < 5? 0 : slope01;
float slope12 = dy21 / dddx;
slope12 = fabs(slope12) < 5? 0 : slope12;
float slope23 = dy32 / dddx;
slope23 = fabs(slope23) < 5? 0 : slope23;
float slope34 = dy43 / dddx;
slope34 = fabs(slope34) < 5? 0 : slope34;
float slope45 = dy54 / dddx;
slope45 = fabs(slope45) < 5? 0 : slope45;
float slope56 = dy65 / dddx;
slope56 = fabs(slope56) < 5? 0 : slope56;
float slope67 = dy76 / dddx;
slope67 = fabs(slope67) < 5? 0 : slope67;
Serial.print(slope01);
Serial.print(",");
Serial.print(slope12);
Serial.print(",");
Serial.print(slope23);
Serial.print(",");
Serial.print(slope34);
Serial.print(",");
Serial.print(slope45);
Serial.print(",");
Serial.print(slope56);
Serial.print(",");
Serial.print(slope67);
Serial.println();
}
void displayBand(int band, int dsize){
uint16_t hpos = BANDS_WIDTH*band + (BANDS_PADDING/2);
if (dsize < 0) dsize = 0;
if(dsize>TFT_HEIGHT-10) {
dsize = TFT_HEIGHT-10; // leave some hspace for text
}
if(dsize < audiospectrum[band].lastval) {
// lower value, delete some lines
M5.Lcd.fillRect(hpos, TFT_HEIGHT-audiospectrum[band].lastval,
BAR_WIDTH, audiospectrum[band].lastval - dsize,BLACK);
}
for (int s = 0; s <= dsize; s=s+4){
uint16_t ypos = TFT_HEIGHT - s;
M5.Lcd.drawFastHLine(hpos, ypos, BAR_WIDTH, colormap[ypos]);
}
if (dsize > audiospectrum[band].peak){audiospectrum[band].peak = dsize;}
audiospectrum[band].lastval = dsize;
}
// Map an FFT bin index to its display band.
// Band b covers bins [2 << b, 4 << b): 2-3 -> 0, 4-7 -> 1, ... 256-511 -> 7.
// Returns 8 for any bin outside those ranges.
byte getBand(int i) {
  if (i < 2) {
    return 8;
  }
  for (byte band = 0; band < 8; ++band) {
    if (i < (4 << band)) {
      return band;
    }
  }
  return 8;
}
// Number of samples the oscilloscope should plot for a detected frequency f.
// Returns SAMPLES/2 when f is exactly 0; otherwise 2 * DYWAPT_SAMPLERATE / f,
// halved repeatedly until it no longer exceeds OSC_SAMPLES.
float calcNumSamples(float f) {
  if (f == 0.0) return SAMPLES/2;

  float count = (float)(DYWAPT_SAMPLERATE * 2) / f;
  while (count > (float)OSC_SAMPLES) {
    count *= 0.5;
  }
  return count;
}
// Draw the detected frequency, the nearest note name with octave, and the
// deviation in cents along the top of the sprite. Does nothing unless freq > 0.
//
// fnote is the fractional note number 12*log2(freq) - 36.376..., which equals
// 69 + 12*log2(freq / 440); note is fnote rounded to the nearest integer.
void showFreq(float freq) {
  if (!(freq > 0)) {
    return;
  }

  char strbuf[16];
  // snprintf keeps the write inside strbuf even for unexpectedly large values.
  snprintf(strbuf, sizeof strbuf, "%8.2fHz", freq);
  sprite.drawString(strbuf, 0, 0, 1);

  float fnote = log2(freq) * 12 - 36.376316562f;
  int note = fnote + 0.5f;
  if (note >= 0) {
    sprite.setCursor(TFT_WIDTH / 2 - 4, 0);
    sprite.print(notestr[note % 12]);
    sprite.print(note / 12 - 1);
    float cent = (fnote - note) * 100;
    snprintf(strbuf, sizeof strbuf, "%.1fcents", cent);
    sprite.drawRightString(strbuf, TFT_WIDTH, 0, 1);
  }
}
// Render one frame of the oscilloscope view into the off-screen sprite and
// push it to the LCD: run the pitch tracker on the completed vTemp buffer,
// plot a window of the matching raw oscbuf samples scaled to the display, and
// overlay the detected frequency via showFreq().
void showOscilloscope()
{
// NOTE(review): j is declared but never used in this function.
uint16_t i,j;
// curbuf^1 is the buffer loop() finished filling (loop() flips curbuf after a fill).
uint16_t *oscbuf_ = oscbuf[curbuf^1];
int *vTemp_ = vTemp[curbuf ^ 1];
float freq = dywapitch_computepitch(&pitchTracker, vTemp_);
#if OSC_EXTRASKIP > 0
// Redraw only after OSC_EXTRASKIP frames have been skipped.
if (skipcount < OSC_EXTRASKIP) {
++skipcount;
return;
}
skipcount = 0;
#endif
// Number of samples to plot (calcNumSamples() result truncated to an integer).
uint16_t s = calcNumSamples(freq);
// Horizontal pixels per sample.
float mx = (float)TFT_WIDTH / s;
// Vertical pixels per raw sample unit (set below).
float my;
uint16_t maxV = 0;
uint16_t minV = 65535;
uint16_t offset = 0;
// Find min/max over the first s samples. Each time a new minimum is found the
// plotted window is moved to start there, provided s samples still fit.
for (i = 0; i < s; ++i) {
if (maxV < oscbuf_[i]) maxV = oscbuf_[i];
if (minV > oscbuf_[i]) {
minV = oscbuf_[i];
if (i + s <= OSC_SAMPLES) offset = i;
}
}
// Scale the peak-to-peak span to TFT_HEIGHT-10 pixels. For spans at or below
// OSC_NOISEFLOOR use a fixed scale and centre the trace on the midpoint.
if (maxV - minV > OSC_NOISEFLOOR) {
my = (float)(TFT_HEIGHT-10) / (maxV - minV);
}
else {
my = (float)(TFT_HEIGHT-10) / OSC_NOISEFLOOR;
// NOTE(review): minV is uint16_t, so this wraps if the midpoint is below OSC_NOISEFLOOR/2.
minV = (((int)maxV + (int)minV) >> 1) - OSC_NOISEFLOOR/2;
}
sprite.fillSprite(BLACK);
// Connect consecutive samples of the window with line segments.
uint16_t y = TFT_HEIGHT - (oscbuf_[offset] - minV) * my;
for (i = 1; i < s; ++i) {
uint16_t y2 = TFT_HEIGHT - (oscbuf_[offset + i] - minV) * my;
sprite.drawLine((uint16_t)((i-1) * mx), y,
(uint16_t)(i * mx), y2, LIGHTGREY);
y = y2;
}
showFreq(freq);
sprite.pushSprite(0,0);
}
// Render one frame of the tuner view: the nearest note name on top and a dot
// whose horizontal position shows the deviation in cents. The background turns
// green when the deviation is under 2 cents. With no usable pitch only the
// empty gauge is drawn.
void showTuner() {
  int *vTemp_ = vTemp[curbuf ^ 1];
  float freq = dywapitch_computepitch(&pitchTracker, vTemp_);

  // Initialised up front so no path can read an indeterminate value.
  float fnote = 0;
  int note = -1;
  if (freq > 0) {
    // Fractional note number; same formula as showFreq().
    fnote = log2(freq) * 12 - 36.376316562;
    note = fnote + 0.5;
  }

  if (note >= 0) {
    float cent = (fnote - note) * 100;
    // fabsf keeps the comparison in floating point however the core defines abs.
    uint32_t bgcolor = (fabsf(cent) < 2.f) ? GREEN : DARKGREY;
    uint32_t fgcolor = BLACK;

    // fillSprite already covers the whole sprite; no second full-size fill needed.
    sprite.fillSprite(bgcolor);
    sprite.setTextColor(fgcolor);
    sprite.drawRect(2, 36, TFT_WIDTH - 3, TFT_HEIGHT - 40, fgcolor);
    sprite.drawRect(TFT_WIDTH / 2 + 1, 36, 1, TFT_HEIGHT - 40, fgcolor);
    sprite.fillCircle(((float)TFT_WIDTH / 2 + 1) + cent * ((float)(TFT_WIDTH - 3) / 100),
                      (TFT_HEIGHT + 34) / 2, 5, fgcolor);

    char strbuf[8];
    snprintf(strbuf, sizeof strbuf, "%s%d", notestr[note % 12], note / 12 - 1);
    sprite.drawCentreString(strbuf, TFT_WIDTH / 2, 3, 4);
  } else {
    sprite.fillSprite(DARKGREY);
    sprite.drawRect(2, 36, TFT_WIDTH - 3, TFT_HEIGHT - 40, BLACK);
    sprite.drawLine(TFT_WIDTH / 2 + 1, 36, TFT_WIDTH / 2 + 1, TFT_HEIGHT - 4, BLACK);
  }
  sprite.pushSprite(0, 0);
}
// Render task started by setup(). It re-initialises the view when loop()
// requests it, renders whenever loop() signals a completed buffer, and
// otherwise sleeps for 10 ticks.
void looptask(void *) {
  for (;;) {
    if (needinit) {
      initMode();
      needinit = false;
    }

    if (!semaphore) {
      vTaskDelay(10);
      continue;
    }

    // A buffer is ready: dispatch to the renderer for the active mode.
    switch (runmode) {
      case ModeSpectrumBars:
        showSpectrumBars();
        break;
      case ModeOscilloscope:
        showOscilloscope();
        break;
      case ModeTuner:
        showTuner();
        break;
    }
    semaphore = false;
  }
}
void loop() {
M5.update();
if (M5.BtnA.wasReleased()) {
++runmode;
if (runmode >= ModeCount) runmode = 0;
bufposcount = 0;
needinit = true;;
}
uint16_t i,j;
j = bufposcount * SAMPLES;
uint16_t *adcBuffer = &oscbuf[curbuf][j];
size_t bytesread;
i2s_read(I2S_NUM_0,(char*)adcBuffer,READ_LEN,&bytesread,portMAX_DELAY);
int32_t dc = 0;
for (int i = 0; i < SAMPLES; ++i) {
dc += adcBuffer[i];
}
dc /= SAMPLES;
switch(runmode) {
case ModeSpectrumBars:
for (int i = 0; i < SAMPLES; ++i) {
vTemp[curbuf][i] = (int)adcBuffer[i] - dc;
vTemp[curbuf][i + SAMPLES] = 0;
}
curbuf ^= 1;
semaphore = true;
break;
case ModeOscilloscope:
for (i = 0; i < SAMPLES; ++i) {
vTemp[curbuf][i + j] = (int)adcBuffer[i] - dc;
}
if (++bufposcount >= OSC_SKIPCOUNT) {
bufposcount = 0;
curbuf ^= 1;
semaphore = true;
}
break;
case ModeTuner:
j = bufposcount * SAMPLES;
for (i = 0; i < SAMPLES; ++i) {
vTemp[curbuf][i + j] = (int)adcBuffer[i] - dc;
}
if (++bufposcount >= OSC_SKIPCOUNT) {
bufposcount = 0;
curbuf ^= 1;
semaphore = true;
}
break;
}
}https://github.com/KKQ-KKQ/m5stickc-audiospectrum/tree/master
用的是这个,但是有修改:
int dy10=audiospectrum[1].lastpeak-audiospectrum[0].lastpeak;
int dy21=audiospectrum[2].lastpeak-audiospectrum[1].lastpeak;
int dy32=audiospectrum[3].lastpeak-audiospectrum[2].lastpeak;
int dy43=audiospectrum[4].lastpeak-audiospectrum[3].lastpeak;
int dy54=audiospectrum[5].lastpeak-audiospectrum[4].lastpeak;
int dy65=audiospectrum[6].lastpeak-audiospectrum[5].lastpeak;
int dy76=audiospectrum[7].lastpeak-audiospectrum[6].lastpeak;
float dddx = 0.455;
//计算出来了7段斜率
float slope01 = dy10 / dddx;
slope01 = fabs(slope01) < 5? 0 : slope01;
float slope12 = dy21 / dddx;
slope12 = fabs(slope12) < 5? 0 : slope12;
float slope23 = dy32 / dddx;
slope23 = fabs(slope23) < 5? 0 : slope23;
float slope34 = dy43 / dddx;
slope34 = fabs(slope34) < 5? 0 : slope34;
float slope45 = dy54 / dddx;
slope45 = fabs(slope45) < 5? 0 : slope45;
float slope56 = dy65 / dddx;
slope56 = fabs(slope56) < 5? 0 : slope56;
float slope67 = dy76 / dddx;
slope67 = fabs(slope67) < 5? 0 : slope67;
Serial.print(slope01);
Serial.print(",");
Serial.print(slope12);
Serial.print(",");
Serial.print(slope23);
Serial.print(",");
Serial.print(slope34);
Serial.print(",");
Serial.print(slope45);
Serial.print(",");
Serial.print(slope56);
Serial.print(",");
Serial.print(slope67);
Serial.println();

2、PC机器接收串口并转发至mqtt部分(同时也是保存csv的部分):
import serial
import pandas as pd
import paho.mqtt.client as mqtt
import json
try:
    # Open the serial port the firmware prints its slope lines on.
    ser = serial.Serial('COM5', 115200)
except serial.SerialException as e:
    print(f"Error opening COM5: {e}")
    # Without the port nothing below can work (`ser` would be undefined),
    # so stop here instead of failing later with a NameError.
    raise SystemExit(1)
# Counter for the optional CSV export file names (traindata_<n>.csv).
export_csv_filename = 0
def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: report whether the broker accepted the connection.

    ``rc`` is the result code passed by the client library; 0 means success.
    """
    if rc != 0:
        print("Connection failed with code", rc)
        return
    print("Connected to MQTT broker")
client = mqtt.Client()
client.on_connect = on_connect
client.connect("192.168.50.232", 1883, 60)
# Run the client's network loop in a background thread. Without it the
# on_connect callback never fires and no keep-alive traffic is exchanged
# between publishes.
client.loop_start()
# Number of valid serial readings bundled into one published frame.
FRAME_ROWS = 120
# Column names, in the order the firmware prints the values on each line.
SLOPE_COLUMNS = ["slope01", "slope12", "slope23", "slope34", "slope45", "slope56", "slope67"]


def parse_slope_line(raw):
    """Turn one raw serial line into a list of seven floats.

    Returns None when the line is empty, is not valid UTF-8, does not have
    exactly seven comma-separated fields, or has a field that is not a number.
    """
    try:
        fields = raw.decode('utf-8').strip().split(',')
        if len(fields) != len(SLOPE_COLUMNS):
            return None
        return [float(field) for field in fields]
    except ValueError:  # also covers UnicodeDecodeError
        return None


try:
    while True:
        # Collect exactly FRAME_ROWS valid readings. Malformed lines are skipped
        # instead of being counted, so every published frame has the same shape.
        rows = []
        while len(rows) < FRAME_ROWS:
            data = ser.readline()
            parsed = parse_slope_line(data)
            if parsed is None:
                if data.strip():
                    print(f"programe error: skipping malformed line {data!r}")
                continue
            rows.append(parsed)

        # Rows are indexed 0..FRAME_ROWS-1, one column per slope.
        df = pd.DataFrame(rows, columns=SLOPE_COLUMNS)
        print("打印df啊")
        print(df)
        # Optional CSV export for collecting training data:
        # export_csv_filename = export_csv_filename +1
        # df.to_csv('traindata_'+ str(export_csv_filename) +'.csv', index=False)  # do not save the index
        # Serialise the frame as JSON: {column: {row_index: value}}.
        serialized_df = json.dumps(df.to_dict())
        client.publish("lemon_ken_mic", serialized_df)
except KeyboardInterrupt:
    pass
finally:
    # Close the serial port on Ctrl+C or on any error that ends the loop.
    ser.close()
    print("======================关闭串口====================")
# python -m venv venv
# .\venv\Scripts\Activate.ps1
# pip install pyserial
# pip install paho-mqtt
# pip install pandas
# python mqtt_send.py

3、训练以及评估部分:
import torch
import pandas as pd
# Neural-network definition
class AudioClassifier(torch.nn.Module):
    """Fully connected classifier: in_features -> 128 -> 64 -> 32 -> 2.

    ``batch_size`` is accepted but not used. The attribute names below are the
    keys of the saved state_dict, so they must not be renamed.
    """

    def __init__(self, batch_size, in_features):
        super().__init__()
        self.layer1 = torch.nn.Linear(in_features, 128)
        self.additional_layer = torch.nn.Linear(128, 64)
        self.layer2 = torch.nn.Linear(64, 32)
        self.layer3 = torch.nn.Linear(32, 2)  # two classes assumed

    def forward(self, x):
        """Return the raw two-class scores (no softmax applied) for ``x``."""
        for hidden in (self.layer1, self.additional_layer, self.layer2):
            x = torch.relu(hidden(x))
        return self.layer3(x)
# Read and prepare the data
def load_data(file_paths):
    """Load one sample per CSV file and derive its label from the path.

    Each file is read with ``skiprows=1``: the file's first line is skipped and
    the next line is consumed as the header, so it does not appear in the data.
    NOTE(review): if the CSVs start with a header line, this drops the first
    data row as well; the rest of the pipeline is sized for that, so confirm
    before changing it.

    A path containing "c0" is labelled 0 and one containing "c1" is labelled 1
    ("c0" is checked first).

    Returns ``(data, labels)``: ``data`` stacks the per-file value matrices
    along a new first dimension, ``labels`` holds one integer per file.

    Raises ValueError if a path contains neither "c0" nor "c1", which would
    otherwise leave data and labels with different lengths.
    """
    import numpy as np

    data = []
    labels = []
    for path in file_paths:
        if "c0" in path:
            label = 0
        elif "c1" in path:
            label = 1
        else:
            raise ValueError(f"cannot derive a label from path {path!r}: expected 'c0' or 'c1'")
        df = pd.read_csv(path, skiprows=1)
        data.append(df.values)
        labels.append(label)
    if not data:
        return torch.tensor(data), torch.tensor(labels)
    # Stack into one ndarray first: building a tensor from a list of ndarrays is slow.
    return torch.from_numpy(np.stack(data)), torch.tensor(labels)
# Training set: recordings whose file name contains "c0" are class 0 and those
# containing "c1" are class 1 (see load_data).
train_file_paths = ["traindata_c0_1.csv","traindata_c0_2.csv","traindata_c0_3.csv","traindata_c0_4.csv"
,"traindata_c0_5.csv","traindata_c0_6.csv","traindata_c0_11.csv","traindata_c0_12.csv",
"traindata_c0_13.csv","traindata_c0_14.csv","traindata_c0_15.csv","traindata_c0_16.csv",
"traindata_c0_17.csv","traindata_c0_18.csv","traindata_c0_19.csv","traindata_c0_20.csv",
"traindata_c0_21.csv","traindata_c0_22.csv",
"traindata_c1_1.csv","traindata_c1_2.csv","traindata_c1_3.csv","traindata_c1_4.csv"
,"traindata_c1_5.csv","traindata_c1_6.csv","traindata_c1_11.csv","traindata_c1_12.csv"
,"traindata_c1_13.csv","traindata_c1_14.csv","traindata_c1_15.csv","traindata_c1_16.csv"
,"traindata_c1_17.csv","traindata_c1_18.csv","traindata_c1_19.csv","traindata_c1_20.csv"]
train_data, train_labels = load_data(train_file_paths)
# Print the dtypes of the loaded tensors as a sanity check.
print("train_data.dtype:"+str(train_data.dtype))
print("train_labels.dtype:"+str(train_labels.dtype))
# Create the network.
# torch.nn.Linear expects a 2-D input [batch_size, in_features] and produces
# [batch_size, out_features]. Each sample is a (rows x cols) matrix, so it is
# flattened into one vector of rows*cols features and the samples are stacked
# into a [num_of_samples, rows*cols] tensor.
num_of_samples = train_data.size(0)
num_of_rows = train_data[0].shape[0]
num_of_cols = train_data[0].shape[1]
print("num_of_samples:")
print(num_of_samples)
print("num_of_rows:")
print(num_of_rows)
print("num_of_cols:")
print(num_of_cols)
model = AudioClassifier(num_of_samples, num_of_rows * num_of_cols)
# Flatten with reshape (permute only reorders dimensions). Use the measured
# sample shape rather than a hard-coded 119*7 so the tensor always matches the
# in_features the model was just built with.
new_train_data_tensor = train_data.reshape(num_of_samples, num_of_rows * num_of_cols)
print("new_train_data_tensor shape:")
print(new_train_data_tensor.shape)  # shape of the flattened tensor
# Loss function and optimizer
loss_func = torch.nn.CrossEntropyLoss()
criterion = torch.nn.BCELoss()  # default parameters; not used by the loop below
optimizer = torch.optim.Adam(model.parameters())
print("=========tow input shapes=============")
print(new_train_data_tensor.shape, train_labels.shape)
print("=========tow input shapes=============")
# Training loop: 200 full-batch steps over the whole training set.
train_inputs = new_train_data_tensor.to(torch.float32)
for epoch in range(200):
    optimizer.zero_grad()
    outputs = model(train_inputs)
    loss = loss_func(outputs, train_labels)
    loss.backward()
    optimizer.step()
print("==========outputs.shape:============")
print(outputs.shape)
# print("==========outputs:============")
# print(outputs)
# Example of the raw outputs from one run:
# ==========outputs:============
# tensor([[ 10.3342, -3.6685],
# [ 13.9501, -0.4912],
# [ 15.0824, -1.4321],
# [ 16.4281, -1.1477],
# [ 11.7916, -4.5559],
# [ 15.0455, -5.4270],
# [ 10.3306, -8.4917],
# [ 9.7162, -8.0242],
# [ 11.2048, -7.3689],
# [ 6.1282, -6.2711],
# [ -3.8303, 7.0840],
# [ -4.3532, 17.0647],
# [-15.7246, 44.2592],
# [ -7.8866, 12.5852],
# [ -8.0347, 10.8943],
# [ -1.7064, 9.8779],
# [ -2.2770, 11.1365],
# [ -3.4305, 18.9242],
# [ -3.7132, 28.6930],
# [-17.9371, 42.6237]], grad_fn=<AddmmBackward0>)
# How to read these numbers: each row holds the two raw scores produced by the
# final Linear layer, one per class. They are not probabilities: they are not
# normalised, do not sum to 1 and may be negative. Only their relative size
# matters; the larger score marks the class the model favours. Softmax, applied
# below, converts each row into values in [0, 1] that sum to 1.
print("==========probabilities:============")
probabilities = torch.softmax(outputs, dim=1)
print(probabilities)
# tensor([[1.0000e+00, 4.4273e-09],
# [1.0000e+00, 1.5222e-12],
# [1.0000e+00, 3.5589e-12],
# [1.0000e+00, 3.0917e-11],
# [1.0000e+00, 5.3411e-10],
# [1.0000e+00, 1.3765e-08],
# [1.0000e+00, 1.4835e-08],
# [1.0000e+00, 4.2492e-07],
# [1.0000e+00, 5.1067e-07],
# [9.9997e-01, 2.8726e-05],
# [5.4728e-06, 9.9999e-01],
# [5.5860e-08, 1.0000e+00],
# [3.0678e-19, 1.0000e+00],
# [9.8549e-08, 1.0000e+00],
# [9.0871e-08, 1.0000e+00],
# [4.0230e-06, 1.0000e+00],
# [4.5366e-06, 1.0000e+00],
# [9.9386e-07, 1.0000e+00],
# [5.5348e-12, 1.0000e+00],
# [1.3492e-18, 1.0000e+00]], grad_fn=<SoftmaxBackward0>)
#====================================================================================================
# Model evaluation on held-out recordings
eval_file_paths = ["traindata_c0_7.csv","traindata_c0_8.csv","traindata_c0_9.csv","traindata_c0_10.csv",
                   "traindata_c1_7.csv","traindata_c1_8.csv","traindata_c1_9.csv","traindata_c1_10.csv",
                   "traindata_c1_21.csv","traindata_c1_22.csv","traindata_c1_23.csv"]
eval_data, eval_labels = load_data(eval_file_paths)
# Flatten every sample to one feature vector with reshape (permute only
# reorders dimensions). -1 lets the feature count follow the loaded data
# instead of a hard-coded 119*7; a mismatch with the model then fails in the
# forward pass with a clear shape error.
num_of_eval_samples = eval_data.size(0)
new_eval_data_tensor = eval_data.reshape(num_of_eval_samples, -1)
print("new_eval_data_tensor shape:")
print(new_eval_data_tensor.shape)  # shape of the flattened tensor
# Evaluate without tracking gradients.
with torch.no_grad():
    test_outputs = model(new_eval_data_tensor.to(torch.float32))
    predicted_labels = torch.argmax(test_outputs, dim=1)
    accuracy = (predicted_labels == eval_labels).sum().item() / eval_labels.size(0)
    print("Accuracy:", accuracy)
    print("==========评估阶段的predicted_labels:============")
    print(predicted_labels)
    print("==========评估阶段的eval_labels:============")
    print(eval_labels)
    print("==========评估阶段的probabilities:============")
    probabilities = torch.softmax(test_outputs, dim=1)
    print(probabilities)
# 假设已经训练好的模型为 model
torch.save(model.state_dict(), 'odel_weights.pth')

4、推理部分
import torch
import pandas as pd
import paho.mqtt.client as mqtt
import json
import tkinter as tk
# Neural-network definition (must match the class used when the weights were saved)
class AudioClassifier(torch.nn.Module):
    """Fully connected classifier: in_features -> 128 -> 64 -> 32 -> 2.

    ``batch_size`` is accepted but not used. The attribute names below are the
    keys looked up by load_state_dict, so they must not be renamed.
    """

    def __init__(self, batch_size, in_features):
        super().__init__()
        self.layer1 = torch.nn.Linear(in_features, 128)
        self.additional_layer = torch.nn.Linear(128, 64)
        self.layer2 = torch.nn.Linear(64, 32)
        self.layer3 = torch.nn.Linear(32, 2)  # two classes assumed

    def forward(self, x):
        """Return the raw two-class scores (no softmax applied) for ``x``."""
        for hidden in (self.layer1, self.additional_layer, self.layer2):
            x = torch.relu(hidden(x))
        return self.layer3(x)
model = AudioClassifier(1,119*7) # build the model structure (119 rows x 7 slopes per frame)
model.load_state_dict(torch.load('odel_weights.pth'))
model.eval() # switch to evaluation mode
# Example of a manual prediction (kept for reference):
# # input data
# input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
# # run the model
# output = model(input_data)
# inspect the result
# print(output)
def update_circle(probabilities):
    """Colour the indicator circle from the class probabilities.

    The circle turns red when the first entry is above 0.8 and light blue
    otherwise. The probabilities are also printed for debugging.
    """
    print("====update_circle我拿到了一个啥:====")
    print(probabilities)  # the probabilities themselves, not their shape
    colour = 'red' if probabilities[0] > 0.8 else 'lightblue'
    canvas.itemconfig(circle, fill=colour)
# Tk window with a single circle that update_circle() recolours.
root = tk.Tk()
canvas = tk.Canvas(root, width=400, height=400)
canvas.pack()
# Canvas item id of the indicator; light blue until a prediction arrives.
circle = canvas.create_oval(100, 100, 300, 300, fill='lightblue')
def on_message(client, userdata, msg):
    """MQTT message callback: classify one received frame and update the circle.

    The payload is expected to be the JSON form of a DataFrame
    ({column: {row_index: value}}) with 7 columns and 120 rows. The last row is
    dropped and the remaining 119 x 7 values are flattened into the model
    input. Payloads that cannot be decoded, or whose shape differs, are
    reported and ignored so that one bad frame cannot raise inside the MQTT
    callback.

    NOTE(review): this runs on the MQTT client's thread while the Tk main loop
    runs on the main thread; confirm that updating the canvas from here is
    acceptable for this Tk build.
    """
    try:
        received_data = json.loads(msg.payload.decode())
        df = pd.DataFrame(received_data)
    except ValueError as exc:  # malformed UTF-8/JSON or a non-tabular payload
        print(f"ignoring malformed MQTT payload: {exc}")
        return

    # df.values holds the data without the column headers.
    print(df.values)
    values119 = df.values[:-1]
    if values119.shape != (119, 7):
        print(f"ignoring frame with unexpected shape {values119.shape}; expected (119, 7)")
        return

    try:
        # Convert the data to a float32 tensor.
        tensor = torch.tensor(values119.astype("float32"))
    except (TypeError, ValueError) as exc:
        print(f"ignoring frame with non-numeric values: {exc}")
        return
    print("====reviced tensor shape:====")
    print(tensor.shape)
    new_predict_tensor = tensor.reshape(119*7)

    # Run the model; no gradients are needed for inference.
    with torch.no_grad():
        output = model(new_predict_tensor.to(torch.float32))
    print(output)
    print("====output tensor shape:====")
    print(output.shape)
    print("==========probabilities:============")
    probabilities = torch.softmax(output, dim=0)
    print(probabilities)
    update_circle(probabilities)
def _subscribe_on_connect(client, userdata, flags, rc, *extra):
    """Subscribe (again) every time the broker accepts a connection.

    Subscribing here rather than once after connect() renews the subscription
    after an automatic reconnect.
    """
    if rc == 0:
        client.subscribe("lemon_ken_mic")
    else:
        print("Connection failed with code", rc)


client = mqtt.Client()
client.on_connect = _subscribe_on_connect
client.on_message = on_message
client.connect("192.168.50.232", 1883,65535)
# The network loop runs in a background thread; Tk owns the main thread.
client.loop_start()
root.mainloop()
# python -m venv venv
# .\venv\Scripts\Activate.ps1
# pip install pyserial
# python main.py
# pip install pandas
# pip install PyQt5

https://github.com/lemonhall/esp32_pytorch_audio_nn/tree/main