Learn Simple Linear Regression (SLR)

SLR in Python with statsmodels.api, statsmodels.formula.api, and scikit-learn

Harika Bonthu
May 14, 2021
Simple linear regression graph (the teal scattered points are the actual values, and the red line shows the predicted values)

In this blog, we will

  • learn the basics of regression algorithms.
  • take a sample dataset, perform EDA (Exploratory Data Analysis), and implement SLR (Simple Linear Regression) using statsmodels.api, statsmodels.formula.api, and scikit-learn.

To begin with, what is a regression algorithm?

Regression is a supervised machine learning technique used to predict a continuous output variable.

Linear regression is the simplest regression algorithm. It models the relationship between a dependent variable and one or more independent variables by fitting a linear equation (a best-fit line) to the observed data.

Based on the number of input features, linear regression can be of two types:

  • Simple Linear Regression (SLR)
  • Multiple Linear Regression (MLR)

In Simple Linear Regression (SLR), we have a single input variable from which we predict the output variable, whereas in Multiple Linear Regression (MLR), we predict the output from multiple inputs.

Input variables are also called independent or predictor variables, and the output variable is called the dependent variable.

The equation for SLR is $y = \beta_0 + \beta_1 x + \epsilon$, where $y$ is the dependent variable, $x$ is the predictor, $\beta_0$ and $\beta_1$ are the coefficients (parameters) of the model, and epsilon ($\epsilon$) is a random variable called the error term.

Ordinary Least Squares (OLS) and Gradient Descent are two common methods for finding the coefficients that minimize the sum of squared errors.
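To make the two approaches concrete, here is a minimal NumPy sketch that estimates $\beta_0$ and $\beta_1$ both ways on a small synthetic dataset. The data, learning rate, and iteration count are illustrative assumptions, not values from this post. For SLR, OLS has the closed form $\hat\beta_1 = \sum_i (x_i - \bar x)(y_i - \bar y) / \sum_i (x_i - \bar x)^2$ and $\hat\beta_0 = \bar y - \hat\beta_1 \bar x$, while gradient descent approaches the same values iteratively.

```python
import numpy as np


def ols_slr(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    """Closed-form OLS estimates (beta0, beta1) for simple linear regression."""
    x_mean, y_mean = x.mean(), y.mean()
    beta1 = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    beta0 = y_mean - beta1 * x_mean
    return float(beta0), float(beta1)


def gradient_descent_slr(x, y, lr=0.01, n_iter=20_000):
    """Minimize the mean squared error with batch gradient descent."""
    beta0, beta1 = 0.0, 0.0
    n = len(x)
    for _ in range(n_iter):
        residual = (beta0 + beta1 * x) - y
        # Gradients of (1/n) * sum(residual**2) with respect to beta0 and beta1
        grad0 = 2.0 / n * residual.sum()
        grad1 = 2.0 / n * (residual * x).sum()
        beta0 -= lr * grad0
        beta1 -= lr * grad1
    return float(beta0), float(beta1)


# Hypothetical data: y = 2 + 3x plus Gaussian noise
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=50)
y = 2 + 3 * x + rng.normal(0, 1, size=50)

print("OLS:", ols_slr(x, y))
print("Gradient descent:", gradient_descent_slr(x, y))
```

OLS solves the problem in one step, which is why libraries such as statsmodels use it for small models; gradient descent becomes attractive when the data is too large for a direct solve.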

Let’s begin with a small problem statement.

Problem statement: Build a simple linear regression model to predict the Salary Hike using Years of Experience.

Start by importing the necessary libraries

The necessary libraries are pandas and NumPy to work with dataframes and arrays, matplotlib and seaborn for visualizations, and scikit-learn and statsmodels to build regression models.

Importing necessary libraries

Once we have imported the libraries, we create a pandas DataFrame from the CSV file:

df = pd.read_csv("Salary_Data.csv")

Perform EDA (Exploratory Data Analysis)

The basic steps of EDA are:

  • Understand the dataset
  1. Identify the number of features (columns)
  2. Identify the features (column names)
  3. Identify the size of the dataset
  4. Identify the data type of each feature
  5. Check whether the dataset has empty cells
  6. Count the empty cells in each feature (column)
  • Handle missing values and outliers
  • Encode categorical variables
  • Graphical univariate and bivariate analysis
  • Normalization and scaling
EDA Steps
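The "Understand the dataset" sub-steps map onto a few pandas calls. The sketch below runs them on a tiny hypothetical DataFrame with one missing value (made-up numbers, not the Salary_Data.csv records); with the real data you would pass `df` instead. The helper name `understand_dataset` is hypothetical.

```python
import numpy as np
import pandas as pd


def understand_dataset(df: pd.DataFrame) -> dict:
    """Hypothetical helper collecting the basic EDA facts from the list above."""
    return {
        "n_columns": df.shape[1],                            # 1. number of features
        "columns": list(df.columns),                         # 2. feature names
        "shape": df.shape,                                   # 3. size as (rows, columns)
        "dtypes": df.dtypes.astype(str).to_dict(),           # 4. data type of each feature
        "has_missing": bool(df.isnull().values.any()),       # 5. any empty cells?
        "missing_per_column": df.isnull().sum().to_dict(),   # 6. empty cells per column
    }


# Hypothetical example frame with one missing value
sample = pd.DataFrame({
    "YearsExperience": [1.0, 2.0, 3.0, np.nan],
    "Salary": [40_000.0, 45_000.0, 55_000.0, 60_000.0],
})
info = understand_dataset(sample)
for key, value in info.items():
    print(f"{key}: {value}")
```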

Our dataset has two columns, YearsExperience and Salary, both of float data type. It has 30 records and no null values or outliers.

Graphical Univariate analysis

For univariate analysis, we can use a histogram, density plot, box plot or violin plot, and normal Q-Q plot. They help us understand the distribution of the data points and detect outliers.

A violin plot is a method of plotting numeric data. It is similar to a box plot, with the addition of a rotated kernel density plot on each side.

Graphical Univariate Analysis
Univariate graphical representations
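seaborn offers one-line functions for most of these plots. As a minimal alternative sketch that needs only matplotlib and SciPy, the hypothetical helper `univariate_plots` draws all five univariate views of one column; the simulated input is illustrative, not the post's data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats


def univariate_plots(values, name):
    """Histogram, density, box, violin, and normal Q-Q plots for one column."""
    values = np.asarray(values, dtype=float)
    fig, axes = plt.subplots(1, 5, figsize=(20, 4))

    axes[0].hist(values, bins=10)
    axes[0].set_title(f"Histogram of {name}")

    # Kernel density estimate evaluated on a grid spanning the data
    grid = np.linspace(values.min(), values.max(), 200)
    axes[1].plot(grid, stats.gaussian_kde(values)(grid))
    axes[1].set_title(f"Density of {name}")

    axes[2].boxplot(values)
    axes[2].set_title(f"Box plot of {name}")

    axes[3].violinplot(values)
    axes[3].set_title(f"Violin plot of {name}")

    # probplot draws sample quantiles against normal quantiles on the given axes
    stats.probplot(values, dist="norm", plot=axes[4])
    axes[4].set_title(f"Normal Q-Q plot of {name}")

    fig.tight_layout()
    return fig


# Usage on simulated values (hypothetical, not the post's data)
rng = np.random.default_rng(1)
fig = univariate_plots(rng.normal(5, 2, size=30), "simulated YearsExperience")
# Call fig.savefig("univariate.png") or plt.show() to view the figure.
```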

From the plots above, there appear to be no outliers in our data; YearsExperience looks approximately normally distributed, while Salary does not. We can check this with the Shapiro-Wilk test.

Shapiro test to check Normality
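A minimal sketch of the Shapiro-Wilk check with `scipy.stats.shapiro`, whose null hypothesis is that the sample comes from a normal distribution. The two inputs are hypothetical: evenly spaced normal quantiles (perfectly bell-shaped) and exponential quantiles (right-skewed), chosen so the outcome is deterministic. The helper name `shapiro_report` is made up.

```python
import numpy as np
from scipy import stats


def shapiro_report(values, name, alpha=0.05):
    """Print the Shapiro-Wilk result and return True if normality is not rejected."""
    stat, p_value = stats.shapiro(values)
    looks_normal = bool(p_value > alpha)
    verdict = "looks normal (fail to reject H0)" if looks_normal else "not normal (reject H0)"
    print(f"{name}: W={stat:.3f}, p={p_value:.4f} -> {verdict}")
    return looks_normal


probs = np.linspace(0.01, 0.99, 30)
normal_like = 5 + 2 * stats.norm.ppf(probs)  # hypothetical: normal quantiles
skewed = stats.expon.ppf(probs)              # hypothetical: exponential quantiles

shapiro_report(normal_like, "normal_like")
shapiro_report(skewed, "skewed")
```

Note that a p-value above 0.05 means only that the test found no evidence against normality; it does not prove the data is normal.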

Our reading of the graphs was correct: YearsExperience is consistent with a normal distribution, and Salary isn’t normally distributed.

Bivariate visualization

For numerical vs. numerical data, we can plot the following graphs:

  1. Scatterplot
  2. Line plot
  3. Heatmap for correlation
  4. Joint plot
Scatter and Line plots
Heatmap
Joint Plot
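seaborn provides `sns.scatterplot`, `sns.lineplot`, `sns.heatmap`, and `sns.jointplot` for these graphs. As a minimal matplotlib-only sketch, the hypothetical helper `scatter_and_line` draws the first two side by side; the sample values are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import pandas as pd


def scatter_and_line(df, x_col, y_col):
    """Side-by-side scatter plot and line plot of two numeric columns."""
    ordered = df.sort_values(x_col)  # sort so the line plot runs left to right
    fig, (ax_scatter, ax_line) = plt.subplots(1, 2, figsize=(10, 4))
    ax_scatter.scatter(df[x_col], df[y_col])
    ax_scatter.set(xlabel=x_col, ylabel=y_col, title="Scatter plot")
    ax_line.plot(ordered[x_col], ordered[y_col], marker="o")
    ax_line.set(xlabel=x_col, ylabel=y_col, title="Line plot")
    fig.tight_layout()
    return fig


# Usage on hypothetical data
sample = pd.DataFrame({
    "YearsExperience": [3.0, 1.0, 2.0, 5.0],
    "Salary": [55_000.0, 40_000.0, 48_000.0, 70_000.0],
})
fig = scatter_and_line(sample, "YearsExperience", "Salary")
```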

Check whether the variables are correlated using df.corr():

Correlation matrix heatmap
heatmap of correlation
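`df.corr()` computes the Pearson correlation by default. The sketch below, on hypothetical data with a strong linear trend, checks it against the definition $r = \mathrm{cov}(x, y) / (\sigma_x \sigma_y)$; the resulting matrix is what a call such as `sns.heatmap(df.corr(), annot=True)` would display.

```python
import numpy as np
import pandas as pd

# Hypothetical data with a strong positive linear trend (not the post's dataset)
rng = np.random.default_rng(7)
years = rng.uniform(1, 10, size=30)
salary = 25_000 + 9_000 * years + rng.normal(0, 5_000, size=30)
df = pd.DataFrame({"YearsExperience": years, "Salary": salary})

corr_matrix = df.corr()  # Pearson correlation by default
print(corr_matrix)

# The same coefficient from its definition: sum of co-deviations / sqrt(product of squared deviations)
x, y = df["YearsExperience"], df["Salary"]
r_manual = ((x - x.mean()) * (y - y.mean())).sum() / np.sqrt(
    ((x - x.mean()) ** 2).sum() * ((y - y.mean()) ** 2).sum()
)
print(round(r_manual, 4))
```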

The correlation is 0.98, a strong positive correlation: the dependent variable (Salary) tends to increase as the independent variable (YearsExperience) increases.

Normalization

As we can see, the YearsExperience and Salary columns are on very different scales. We can use normalization to bring the values of numeric columns to a common scale, without distorting differences in the ranges of values or losing information.

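We use sklearn.preprocessing.normalize to normalize our data. By default it rescales each row (sample) to unit L2 norm; passing axis=0 rescales each column instead. Because all our values are non-negative, the normalized values lie between 0 and 1.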

Code block for Normalization
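The normalization cell is an image, so the exact call used is not visible. As a hedged sketch on hypothetical values, this contrasts `sklearn.preprocessing.normalize` applied column-wise (`axis=0`) with `MinMaxScaler`, the scaler that maps each column's minimum to 0 and maximum to 1.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, normalize

# Hypothetical values on very different scales
df = pd.DataFrame({
    "YearsExperience": [1.0, 3.0, 5.0, 10.0],
    "Salary": [40_000.0, 55_000.0, 70_000.0, 120_000.0],
})

# normalize() scales each ROW to unit L2 norm by default (axis=1);
# axis=0 scales each COLUMN instead, so every column ends up with norm 1.
col_normalized = normalize(df, axis=0)

# Min-max scaling maps each column's min to 0 and its max to 1.
minmax_scaled = MinMaxScaler().fit_transform(df)

print(pd.DataFrame(col_normalized, columns=df.columns))
print(pd.DataFrame(minmax_scaled, columns=df.columns))
```

For a linear model with an intercept, rescaling x or y like this changes the units of the coefficients but not the R-squared, so in SLR scaling is mainly a convenience for comparing values.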

Linear Regression using scikit-learn

LinearRegression() fits a linear model with coefficients β = (β1, …, βp) to minimize the residual sum of squares between the observed targets in the dataset and the targets predicted by the linear approximation.

scikit-learn regressor code block
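The regressor cell is an image. A minimal sketch of the same workflow on hypothetical data looks like this; note that scikit-learn expects a 2-D feature matrix even for a single predictor, and that `LinearRegression.score` returns R-squared.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Hypothetical data standing in for YearsExperience -> Salary
rng = np.random.default_rng(3)
X = rng.uniform(1, 10, size=(30, 1))  # shape (n_samples, 1): one feature column
y = 25_000 + 9_000 * X[:, 0] + rng.normal(0, 5_000, size=30)

model = LinearRegression()  # fits an intercept by default
model.fit(X, y)

print("intercept (beta0):", model.intercept_)
print("slope (beta1):", model.coef_[0])
# score() returns the R-squared of the model's predictions on (X, y)
print("R-squared:", model.score(X, y))
print("R-squared via r2_score:", r2_score(y, model.predict(X)))
```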

We achieved an R-squared of 0.957 (95.7%) using scikit-learn, but this model offers little insight into the relevance of the features. So let’s build models using statsmodels.api and statsmodels.formula.api.

Linear Regression using statsmodels.formula.api (smf)

In statsmodels.formula.api, the predictors must be listed individually in an R-style formula string (for example, "Salary ~ YearsExperience"), and a constant (intercept) is added automatically.

statsmodels.formula.api linear regression model code block
Bar plot of Actuals vs predicted values
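Since the model cell is an image, here is a minimal sketch of the formula interface on hypothetical data. The formula names the response on the left of `~` and the predictor on the right, and the fitted values give the actual-vs-predicted table behind a bar plot such as `comparison.head(10).plot(kind="bar")`.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Hypothetical data standing in for Salary_Data.csv
rng = np.random.default_rng(3)
years = rng.uniform(1, 10, size=30)
df = pd.DataFrame({
    "YearsExperience": years,
    "Salary": 25_000 + 9_000 * years + rng.normal(0, 5_000, size=30),
})

# R-style formula: response ~ predictor; the intercept is added automatically
model = smf.ols("Salary ~ YearsExperience", data=df).fit()
print(model.params)  # Intercept and YearsExperience coefficients

# Actual vs. predicted values side by side
comparison = pd.DataFrame({"Actual": df["Salary"], "Predicted": model.fittedvalues})
print(comparison.head())
```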

Regression using statsmodels.api

Here, the predictors no longer have to be enumerated individually; we pass the whole predictor matrix at once.

statsmodels.regression.linear_model.OLS(endog, exog)

  • endog is the dependent variable
  • exog is the independent variable(s). An intercept is not included by default and should be added by the user (using sm.add_constant).
statsmodels.api linear regression model code block

We achieved an R-squared of 0.957 (95.7%), which is pretty good :)

What does the model summary table say??? 😕

It’s important to understand the key terms in the regression model summary table, because they tell us how well our model performs and how relevant the input variables are.

OLS Regression Results summary

Some important values to consider are R-squared, Adj. R-squared, F-statistic, Prob (F-statistic), the coef of the intercept and input variables, and P>|t|.

  • R-squared is the coefficient of determination: the proportion of the variance in the dependent variable that is explained by the model. An R-squared value closer to 1 indicates a better fit.
  • Adj. R-squared penalizes the R-squared value when we add features that do not contribute to the model’s predictions. Adj. R-squared is always less than or equal to R-squared; a large gap between the two is a sign that the model contains irrelevant predictors.
  • The F-statistic (F-test) helps us accept or reject the null hypothesis. It compares an intercept-only model with our model that includes the features. The null hypothesis is ‘all of the regression coefficients except the intercept are equal to zero, so both models are equally good’. The alternative hypothesis is ‘the intercept-only model is worse than our model, so the added coefficients improved the model performance’. If Prob (F-statistic) < 0.05 and the F-statistic is high, we reject the null hypothesis, which indicates a significant relationship between the input and the output variables.
  • coef shows the estimated coefficients of the corresponding input features.
  • The t-test assesses the relationship between the output and each input variable individually. The null hypothesis is ‘the coef of an input feature is 0’; the alternative hypothesis is ‘the coef of an input feature is not 0’. If the p-value (P>|t|) < 0.05, we reject the null hypothesis, which indicates a significant relationship between that input variable and the output variable. Variables whose p-value is > 0.05 are candidates for elimination.

Now that we know how to draw important inferences from the model summary table, let’s look at our model’s parameters and evaluate it.

In our case, the R-squared value (0.957) is close to the Adj. R-squared value (0.955), which is a good sign that the input feature contributes to the model.

The F-statistic is high and Prob (F-statistic) is almost 0, which means our model is better than the intercept-only model.

The t-test p-value for the input variable is less than 0.05, so there is a significant relationship between the input and the output variable.

Hence, we conclude that our model performs well ✔😊

In this blog, we learned the basics of Simple Linear Regression (SLR), built a linear model using different Python libraries, and drew inferences from the statsmodels OLS summary table.


Check out the complete notebook from my GitHub repository.

I hope this is an informative blog for beginners. Please upvote if you find this helpful 🙌 Feedback is highly appreciated. Happy learning!! 😎
