In this notebook, we will explore the severity of crashes, as it will be the target of our predictive models.
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.stats as st
import matplotlib.pyplot as plt
import seaborn as sb
from crash_prediction import cas_data
# set seaborn default style
sb.set()
But first, we ensure we have the data or download it if needed
dset_path = Path("..") / "data" / "cas_dataset.csv"
if not dset_path.exists():
dset_path.parent.mkdir(parents=True, exist_ok=True)
cas_data.download(dset_path)
and load it.
dset = pd.read_csv(dset_path)
dset.head()
| | X | Y | OBJECTID | advisorySpeed | areaUnitID | bicycle | bridge | bus | carStationWagon | cliffBank | ... | train | tree | truck | unknownVehicleType | urban | vanOrUtility | vehicle | waterRiver | weatherA | weatherB |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 174.802308 | -41.287738 | 1 | NaN | 576700.0 | 0.0 | 0.0 | 0.0 | 2.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | Urban | 0.0 | 0.0 | 0.0 | Fine | Null |
1 | 174.423436 | -36.077085 | 2 | NaN | 501815.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | Open | 0.0 | 0.0 | 0.0 | Fine | Null |
2 | 174.885342 | -36.985339 | 3 | NaN | 523722.0 | 0.0 | NaN | 0.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 1.0 | NaN | NaN | Fine | Null |
3 | 173.897841 | -35.917064 | 4 | NaN | 504502.0 | 1.0 | NaN | 0.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Open | 0.0 | NaN | NaN | Fine | Null |
4 | 174.761113 | -36.863304 | 5 | NaN | 514301.0 | 0.0 | NaN | 1.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 0.0 | NaN | NaN | Fine | Null |
5 rows × 72 columns
The CAS dataset has 4 features that can be associated with the crash severity:

- `crashSeverity`, severity of a crash, determined by the worst injury sustained in the crash at time of entry,
- `fatalCount`, count of the number of fatal casualties associated with this crash,
- `minorInjuryCount`, count of the number of minor injuries associated with this crash,
- `seriousInjuryCount`, count of the number of serious injuries associated with this crash.

severity_features = [
"fatalCount",
"seriousInjuryCount",
"minorInjuryCount",
"crashSeverity",
]
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
for ax, feat in zip(axes.flat, severity_features):
counts = dset[feat].value_counts(dropna=False)
counts.plot.bar(ylabel="# crashes", title=feat, ax=ax)
ax.set(yscale="log")
fig.tight_layout()
To check the geographical distribution, we will focus on Auckland and replace the discrete levels of crashSeverity with numbers to ease plotting.
dset_auckland = dset[dset["X"].between(174.7, 174.9) & dset["Y"].between(-37, -36.8)]
mapping = {
"Non-Injury Crash": 1,
"Minor Crash": 2,
"Serious Crash": 3,
"Fatal Crash": 4,
}
dset_auckland = dset_auckland.replace({"crashSeverity": mapping})
Given the dataset imbalance, we plot the local maxima to better see the locations of the more severe car crashes.
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
for ax, feat in zip(axes.flat, severity_features):
dset_auckland.plot.hexbin(
"X",
"Y",
feat,
gridsize=500,
reduce_C_function=np.max,
cmap="BuPu",
title=feat,
ax=ax,
sharex=False,
)
ax.set_xticklabels([])
ax.set_yticklabels([])
fig.tight_layout()
A few remarks based on these plots:
The crash severity is probably a good go-to target, as it's quite interpretable and actionable. The corresponding ML problem is a supervised multi-class prediction problem.
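As a point of reference for this multi-class formulation, here is a minimal baseline sketch, assuming scikit-learn is available and using only the crash coordinates as placeholder features (any real model would of course use a proper feature set):

    from sklearn.dummy import DummyClassifier
    from sklearn.model_selection import train_test_split

    # placeholder features: just the crash coordinates
    features = dset[["X", "Y"]]
    target = dset["crashSeverity"]

    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.2, random_state=42
    )

    # majority-class baseline: always predict the most frequent severity level
    baseline = DummyClassifier(strategy="most_frequent")
    baseline.fit(X_train, y_train)
    print(f"majority-class baseline accuracy: {baseline.score(X_test, y_test):.3f}")

Because the severity classes are heavily imbalanced, the accuracy of such a baseline will look deceptively high, so any real model should be evaluated with class-aware metrics as well.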
To simplify the problem, we can also just try to predict whether a crash involves an injury (minor, serious or fatal) or not. Here is how it would look in Auckland:
dset_auckland["injuryCrash"] = (dset_auckland["crashSeverity"] > 1) * 1.0
dset_auckland.plot.hexbin(
"X",
"Y",
"injuryCrash",
gridsize=500,
cmap="BuPu",
title="Crash with injury",
sharex=False,
figsize=(10, 10),
)
Interestingly, the major axes do not stand out as much here, because we are averaging instead of taking the local maxima.
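To make this difference explicit, here is a small comparison sketch (reusing the variables defined above) that plots the same binary injuryCrash variable once with the default averaging and once with the local maxima:

    # compare the default reduction (mean, i.e. fraction of injury crashes per cell)
    # with the local maximum (1 as soon as a cell contains any injury crash)
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    for ax, (reducer, title) in zip(
        axes, [(np.mean, "mean of injuryCrash"), (np.max, "max of injuryCrash")]
    ):
        dset_auckland.plot.hexbin(
            "X",
            "Y",
            "injuryCrash",
            gridsize=500,
            reduce_C_function=reducer,
            cmap="BuPu",
            title=title,
            ax=ax,
            sharex=False,
        )
    fig.tight_layout()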
This brings us to another question: is the fraction of crashes with injuries constant, irrespective of the number of crashes in an area? If so, a simple binomial model could describe the locally binned data.
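Concretely, under such a model the number of injury crashes in a cell containing n crashes would follow a Binomial(n, p) distribution, with a single probability p shared by all cells. As a quick illustration (the values n=100 and p=0.2 below are made up for the example):

    # illustrative only: a cell with 100 crashes and a constant injury probability of 0.2
    example_cell = st.binom(n=100, p=0.2)
    print("expected number of injury crashes:", example_cell.mean())
    print("95% equal-tail interval:", example_cell.ppf([0.025, 0.975]))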
We first discretize space into 0.01° wide cells and count the total number of crashes in each cell as well as the number of crashes with injuries.
dset["X_bin"] = pd.cut(
dset["X"], pd.interval_range(dset.X.min(), dset.X.max(), freq=0.01)
)
dset["Y_bin"] = pd.cut(
dset["Y"], pd.interval_range(dset.Y.min(), dset.Y.max(), freq=0.01)
)
counts = (
dset.groupby(["X_bin", "Y_bin"], observed=True).size().reset_index(name="crash")
)
injury_counts = (
dset.groupby(["X_bin", "Y_bin"], observed=True)
.apply(lambda x: (x["crashSeverity"] != "Non-Injury Crash").sum())
.reset_index(name="injury")
)
counts = counts.merge(injury_counts)
For each value of the number of crashes per cell, we can check the fraction of crashes with injuries. Here we see that cells with one or a few crashes have a nearly 50/50 chance of injuries, whereas for cells with a larger number of accidents the fraction goes down to about 20%.
injury_fraction = counts.groupby("crash").apply(
lambda x: x["injury"].sum() / x["crash"].sum()
)
ax = injury_fraction.plot(style=".", ylabel="fraction of injuries", figsize=(10, 7))
ax.set_xscale("log")
Then we can also check how good a binomial distribution is at modeling the binned data, using it to derive a 95% predictive interval.
ratio = counts["injury"].sum() / counts["crash"].sum()
xs = np.arange(1, counts["crash"].max() + 1)
pred_intervals = st.binom(xs, ratio).ppf([[0.025], [0.975]])
fig, axes = plt.subplots(1, 2, figsize=(15, 7))
counts.plot.scatter(x="crash", y="injury", alpha=0.3, c="b", s=2, ax=axes[0])
axes[0].fill_between(
xs,
pred_intervals[0],
pred_intervals[1],
alpha=0.3,
color="r",
label="95% equal-tail interval for binomial",
)
axes[0].legend()
counts.plot.scatter(x="crash", y="injury", alpha=0.3, c="b", s=2, ax=axes[1])
axes[1].fill_between(
xs,
pred_intervals[0],
pred_intervals[1],
alpha=0.3,
color="r",
label="95% equal-tail interval for binomial",
)
axes[1].legend()
axes[1].set_xscale("log")
axes[1].set_yscale("log")
The predictive interval seems to have poor coverage, overshooting in the high-count regions and being too narrow for the regions with hundreds of crashes. We can compute the empirical coverage of these intervals to check this.
counts["covered"] = counts["injury"].between(
pred_intervals[0, counts["crash"] - 1], pred_intervals[1, counts["crash"] - 1]
)
print(f"95% predictive interval has {counts['covered'].mean() * 100:.2f}%.")
95% predictive interval has 93.78%.
print("95% predictive interval coverage per quartile of crash counts:")
mask = counts["crash"] > 1
counts[mask].groupby(pd.qcut(counts.loc[mask, "crash"], 4))["covered"].mean()
95% predictive interval coverage per quartile of crash counts:
crash
(1.999, 3.0]      0.958451
(3.0, 5.0]        0.962884
(5.0, 14.0]       0.924798
(14.0, 4685.0]    0.783500
Name: covered, dtype: float64
So it turns out that, at the macro scale, the coverage of this simple model is quite good, but if we split by number of crashes, the coverage degrades for the cells with a higher number of crashes.
Hence, the number of crashes in the vicinity could be a relevant predictor for the probability of a crash with injury.
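A minimal sketch of how such a feature could be derived from the binning above, assuming the 0.01° cells are an acceptable notion of "vicinity" (the column name crashDensity is made up for this example):

    # hypothetical feature: number of crashes recorded in the same 0.01° cell
    # (crashes whose coordinates fall outside the binning end up with NaN)
    dset["crashDensity"] = dset.groupby(["X_bin", "Y_bin"], observed=True)[
        "OBJECTID"
    ].transform("count")
    dset[["X", "Y", "crashSeverity", "crashDensity"]].head()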
!date -R
Thu, 10 Dec 2020 19:25:08 +0000
!uname -a
Linux wbl009 3.10.0-693.2.2.el7.x86_64 #1 SMP Tue Sep 12 22:26:13 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux
!pip freeze
affine==2.3.0 amply==0.1.4 appdirs==1.4.4 async-generator==1.10 attrs==20.3.0 backcall==0.2.0 black==20.8b1 bleach==3.2.1 bokeh==2.2.3 Cartopy @ file:///home/conda/feedstock_root/build_artifacts/cartopy_1604218104490/work certifi==2020.12.5 chardet==3.0.4 click==7.1.2 click-plugins==1.1.1 cligj==0.7.1 cloudpickle==1.6.0 ConfigArgParse==1.2.3 contextily==1.0.1 -e git+git@github.com:neon-ninja/crash_prediction.git@fc9214da20747e5e6cd459b122b0379395b3567c#egg=crash_prediction cycler==0.10.0 dask==2.30.0 dask-glm==0.2.0 dask-jobqueue==0.7.2 dask-ml==1.7.0 datrie==0.8.2 decorator==4.4.2 defopt==6.0.2 defusedxml==0.6.0 distributed==2.30.1 docutils==0.16 entrypoints==0.3 flake8==3.8.4 fsspec==0.8.4 geographiclib==1.50 geopy==2.0.0 gitdb==4.0.5 GitPython==3.1.11 HeapDict==1.0.1 idna==2.10 ipykernel==5.3.4 ipython==7.19.0 ipython-genutils==0.2.0 jedi==0.17.2 Jinja2==2.11.2 joblib==0.17.0 jsonschema==3.2.0 jupyter-client==6.1.7 jupyter-core==4.7.0 jupyterlab-pygments==0.1.2 jupytext==1.7.1 kiwisolver==1.3.1 lightgbm==3.1.1 llvmlite==0.35.0 locket==0.2.0 markdown-it-py==0.5.6 MarkupSafe==1.1.1 matplotlib==3.3.3 mccabe==0.6.1 mercantile==1.1.6 mistune==0.8.4 msgpack==1.0.0 multipledispatch==0.6.0 mypy-extensions==0.4.3 nbclient==0.5.1 nbconvert==6.0.7 nbformat==5.0.8 nest-asyncio==1.4.3 numba==0.52.0 numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1604945996350/work packaging==20.7 pandas==1.1.5 pandocfilters==1.4.3 parso==0.7.1 partd==1.1.0 pathspec==0.8.1 pexpect==4.8.0 pickleshare==0.7.5 Pillow==8.0.1 pockets==0.9.1 prompt-toolkit==3.0.8 psutil==5.7.3 ptyprocess==0.6.0 PuLP==2.3.1 pycodestyle==2.6.0 pyflakes==2.2.0 Pygments==2.7.3 pyparsing==2.4.7 pyrsistent==0.17.3 pyshp @ file:///home/conda/feedstock_root/build_artifacts/pyshp_1599782465740/work python-dateutil==2.8.1 pytz==2020.4 PyYAML==5.3.1 pyzmq==20.0.0 rasterio==1.1.8 ratelimiter==1.2.0.post0 regex==2020.11.13 requests==2.25.0 scikit-learn==0.23.2 scipy==1.5.4 seaborn==0.11.0 Shapely @ file:///home/conda/feedstock_root/build_artifacts/shapely_1602547954120/work six @ file:///home/conda/feedstock_root/build_artifacts/six_1590081179328/work smmap==3.0.4 snakemake==5.30.1 snuggs==1.4.7 sortedcontainers==2.3.0 sphinxcontrib-napoleon==0.7 tblib==1.7.0 testpath==0.4.4 threadpoolctl==2.1.0 toml==0.10.2 toolz==0.11.1 toposort==1.5 tornado==6.1 traitlets==5.0.5 typed-ast==1.4.1 typing-extensions==3.7.4.3 urllib3==1.26.2 wcwidth==0.2.5 webencodings==0.5.1 wrapt==1.12.1 zict==2.0.0