In [19]:
#imports
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
In [20]:
data = pd.read_csv('car_price.csv')
data.head(10)
Out[20]:
car_ID symboling CarName fueltype aspiration doornumber carbody drivewheel enginelocation wheelbase ... enginesize fuelsystem boreratio stroke compressionratio horsepower peakrpm citympg highwaympg price
0 1 3 alfa-romero giulia gas std two convertible rwd front 88.6 ... 130 mpfi 3.47 2.68 9.0 111 5000 21 27 13495.000
1 2 3 alfa-romero stelvio gas std two convertible rwd front 88.6 ... 130 mpfi 3.47 2.68 9.0 111 5000 21 27 16500.000
2 3 1 alfa-romero Quadrifoglio gas std two hatchback rwd front 94.5 ... 152 mpfi 2.68 3.47 9.0 154 5000 19 26 16500.000
3 4 2 audi 100 ls gas std four sedan fwd front 99.8 ... 109 mpfi 3.19 3.40 10.0 102 5500 24 30 13950.000
4 5 2 audi 100ls gas std four sedan 4wd front 99.4 ... 136 mpfi 3.19 3.40 8.0 115 5500 18 22 17450.000
5 6 2 audi fox gas std two sedan fwd front 99.8 ... 136 mpfi 3.19 3.40 8.5 110 5500 19 25 15250.000
6 7 1 audi 100ls gas std four sedan fwd front 105.8 ... 136 mpfi 3.19 3.40 8.5 110 5500 19 25 17710.000
7 8 1 audi 5000 gas std four wagon fwd front 105.8 ... 136 mpfi 3.19 3.40 8.5 110 5500 19 25 18920.000
8 9 1 audi 4000 gas turbo four sedan fwd front 105.8 ... 131 mpfi 3.13 3.40 8.3 140 5500 17 20 23875.000
9 10 0 audi 5000s (diesel) gas turbo two hatchback 4wd front 99.5 ... 131 mpfi 3.13 3.40 7.0 160 5500 16 22 17859.167

10 rows × 26 columns

1. 探索数据特性¶

In [21]:
print("Rows: ",data.shape[0])
print("Columns: ",data.shape[1])
Rows:  205
Columns:  26
In [22]:
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 205 entries, 0 to 204
Data columns (total 26 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   car_ID            205 non-null    int64  
 1   symboling         205 non-null    int64  
 2   CarName           205 non-null    object 
 3   fueltype          205 non-null    object 
 4   aspiration        205 non-null    object 
 5   doornumber        205 non-null    object 
 6   carbody           205 non-null    object 
 7   drivewheel        205 non-null    object 
 8   enginelocation    205 non-null    object 
 9   wheelbase         205 non-null    float64
 10  carlength         205 non-null    float64
 11  carwidth          205 non-null    float64
 12  carheight         205 non-null    float64
 13  curbweight        205 non-null    int64  
 14  enginetype        205 non-null    object 
 15  cylindernumber    205 non-null    object 
 16  enginesize        205 non-null    int64  
 17  fuelsystem        205 non-null    object 
 18  boreratio         205 non-null    float64
 19  stroke            205 non-null    float64
 20  compressionratio  205 non-null    float64
 21  horsepower        205 non-null    int64  
 22  peakrpm           205 non-null    int64  
 23  citympg           205 non-null    int64  
 24  highwaympg        205 non-null    int64  
 25  price             205 non-null    float64
dtypes: float64(8), int64(8), object(10)
memory usage: 41.8+ KB
In [23]:
data.isna().sum()
# 没有空值
Out[23]:
car_ID              0
symboling           0
CarName             0
fueltype            0
aspiration          0
doornumber          0
carbody             0
drivewheel          0
enginelocation      0
wheelbase           0
carlength           0
carwidth            0
carheight           0
curbweight          0
enginetype          0
cylindernumber      0
enginesize          0
fuelsystem          0
boreratio           0
stroke              0
compressionratio    0
horsepower          0
peakrpm             0
citympg             0
highwaympg          0
price               0
dtype: int64
In [24]:
data.duplicated().sum()
# 没有重复值
Out[24]:
0
In [25]:
data.groupby("CarName").sum(numeric_only=True)
Out[25]:
car_ID symboling wheelbase carlength carwidth carheight curbweight enginesize boreratio stroke compressionratio horsepower peakrpm citympg highwaympg price
CarName
Nissan versa 90 1 94.5 165.3 63.8 54.5 1889 97 3.15 3.29 9.4 69 5200 31 37 5499.0
alfa-romero Quadrifoglio 3 1 94.5 171.2 65.5 52.4 2823 152 2.68 3.47 9.0 154 5000 19 26 16500.0
alfa-romero giulia 1 3 88.6 168.8 64.1 48.8 2548 130 3.47 2.68 9.0 111 5000 21 27 13495.0
alfa-romero stelvio 2 3 88.6 168.8 64.1 48.8 2548 130 3.47 2.68 9.0 111 5000 21 27 16500.0
audi 100 ls 4 2 99.8 176.6 66.2 54.3 2337 109 3.19 3.40 10.0 102 5500 24 30 13950.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
volvo 246 204 -1 109.1 188.8 68.9 55.5 3217 145 3.01 3.40 23.0 106 4800 26 27 22470.0
volvo 264gl 404 -3 213.4 377.6 136.1 111.7 6107 271 7.40 6.30 17.0 276 10500 36 47 41045.0
volvo diesel 200 -1 104.3 188.8 67.2 57.5 3157 130 3.62 3.15 7.5 162 5100 17 22 18950.0
vw dasher 190 3 94.5 159.3 64.2 55.6 2254 109 3.19 3.40 8.5 90 5500 24 29 11595.0
vw rabbit 191 3 94.5 165.7 64.0 51.4 2221 109 3.19 3.40 8.5 90 5500 24 29 9980.0

147 rows × 16 columns

In [26]:
# 删除 CarName, CarID 因为它不会给回归任务增加太多价值
data = data.drop(['car_ID','CarName'],axis=1)
data.head(1)
Out[26]:
symboling fueltype aspiration doornumber carbody drivewheel enginelocation wheelbase carlength carwidth ... enginesize fuelsystem boreratio stroke compressionratio horsepower peakrpm citympg highwaympg price
0 3 gas std two convertible rwd front 88.6 168.8 64.1 ... 130 mpfi 3.47 2.68 9.0 111 5000 21 27 13495.0

1 rows × 24 columns

In [27]:
sns.histplot(data=data, x="price")
Out[27]:
<Axes: xlabel='price', ylabel='Count'>
In [28]:
plt.figure(figsize=(15,7))
sns.heatmap(data.corr(numeric_only=True), annot=True)
plt.title("Data Correlation",size=15)
plt.show()
In [29]:
#燃料类型对价格的影响
sns.barplot(x="fueltype", y="price", data=data)
Out[29]:
<Axes: xlabel='fueltype', ylabel='price'>
In [30]:
#车型对价格的影响
sns.boxplot(x ="carbody", y ="price", data = data)
Out[30]:
<Axes: xlabel='carbody', ylabel='price'>
In [31]:
#门数对价格的影响
sns.boxplot(x ="doornumber", y ="price", data = data)
Out[31]:
<Axes: xlabel='doornumber', ylabel='price'>
In [32]:
#驱动器(FWD、RWD、AWD)对价格的影响
sns.boxplot(x ="drivewheel", y ="price", data = data)
Out[32]:
<Axes: xlabel='drivewheel', ylabel='price'>
In [33]:
#绘制热图中最相关属性之间的成对关系
columns=data[['wheelbase','carlength','carwidth','curbweight','price']]
sns.pairplot(columns)
plt.show()
#linear relationship
In [34]:
columns=data[['horsepower','citympg','highwaympg','price']]
sns.pairplot(columns)
plt.show()
#linear relationship

以下属性集具有线性关系:

1.轴距、车长、车宽、整备质量和价格(基本上是所有物理属性)

2.马力、城市英里数、高速公路英里数和价格(基本上是与车辆功率相关的所有属性)

2. 训练模型¶

In [35]:
encoder = LabelEncoder()
data['fueltype'] = encoder.fit_transform(data['fueltype'])
fueltype = {index : label for index, label in enumerate(encoder.classes_)}
data['aspiration'] = encoder.fit_transform(data['aspiration'])
aspiration = {index : label for index, label in enumerate(encoder.classes_)}
data['doornumber'] = encoder.fit_transform(data['doornumber'])
doornumber = {index : label for index, label in enumerate(encoder.classes_)}
data['carbody'] = encoder.fit_transform(data['carbody'])
carbody = {index : label for index, label in enumerate(encoder.classes_)}
data['drivewheel'] = encoder.fit_transform(data['drivewheel'])
drivewheel = {index : label for index, label in enumerate(encoder.classes_)}
data['enginelocation'] = encoder.fit_transform(data['enginelocation'])
enginelocation = {index : label for index, label in enumerate(encoder.classes_)}
data['fuelsystem'] = encoder.fit_transform(data['fuelsystem'])
fuelsystem = {index : label for index, label in enumerate(encoder.classes_)}
data['enginetype'] = encoder.fit_transform(data['enginetype'])
enginetype = {index : label for index, label in enumerate(encoder.classes_)}
data['cylindernumber'] = encoder.fit_transform(data['cylindernumber'])
cylindernumber = {index : label for index, label in enumerate(encoder.classes_)}
data['fuelsystem'] = encoder.fit_transform(data['fuelsystem'])
fuelsystem = {index : label for index, label in enumerate(encoder.classes_)}
In [36]:
x = data.drop('price', axis=1)
y = data['price']
In [37]:
scaler = MinMaxScaler(copy=True, feature_range=(0, 1))
X = scaler.fit_transform(x)
In [38]:
#train, test split
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=30,random_state=0)
  1. 随机森林回归
In [39]:
rf = RandomForestRegressor(n_estimators=100,max_depth=5, random_state=33)
rf.fit(x_train, y_train)
Out[39]:
RandomForestRegressor(max_depth=5, random_state=33)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
RandomForestRegressor(max_depth=5, random_state=33)
In [40]:
print("Training r2_score: ",rf.score(x_train, y_train))
print("Testing r2_score: ",rf.score(x_test, y_test))
Training r2_score:  0.9753559007565417
Testing r2_score:  0.87367804775233
  1. 决策树回归
In [41]:
dt = DecisionTreeRegressor( max_depth=5,random_state=33)
dt.fit(x_train, y_train)
Out[41]:
DecisionTreeRegressor(max_depth=5, random_state=33)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
DecisionTreeRegressor(max_depth=5, random_state=33)
In [42]:
print('Training r2_score: ' , dt.score(x_train, y_train))
print('Testing r2_score: ' , dt.score(x_test, y_test))
Training r2_score:  0.9735394081185511
Testing r2_score:  0.8226507572837073
  1. 线性回归
In [47]:
def evaluate(model,x_train , y_train, x_test , y_test, y_predict):
    print(f'train r2_score:{r2_score(y_train, model.predict(x_train))}' )
    print(f'test r2_score : {r2_score(y_test, y_predict)}')
In [48]:
model = LinearRegression()
model.fit(x_train,y_train)
y_predict=model.predict(x_test)
evaluate(model,x_train , y_train, x_test , y_test, y_predict)
train r2_score:0.889157847638672
test r2_score : 0.7289860743863041