from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sb
import contextily as cx
from crash_prediction import cas_data
# set seaborn default style
sb.set()
First we need to retrieve the dataset from the Open Data portal. Multiple file formats are available (csv, kml, geojson, ...), the most compact being the .csv one.
dset_path = Path("..") / "data" / "cas_dataset.csv"
if not dset_path.exists():
    dset_path.parent.mkdir(parents=True, exist_ok=True)
    cas_data.download(dset_path)
Next we load the data and have a quick look to check that there are no obvious loading errors.
dset = pd.read_csv(dset_path)
dset
| | X | Y | OBJECTID | advisorySpeed | areaUnitID | bicycle | bridge | bus | carStationWagon | cliffBank | ... | train | tree | truck | unknownVehicleType | urban | vanOrUtility | vehicle | waterRiver | weatherA | weatherB |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 174.802308 | -41.287738 | 1 | NaN | 576700.0 | 0.0 | 0.0 | 0.0 | 2.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | Urban | 0.0 | 0.0 | 0.0 | Fine | Null |
| 1 | 174.423436 | -36.077085 | 2 | NaN | 501815.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | Open | 0.0 | 0.0 | 0.0 | Fine | Null |
| 2 | 174.885342 | -36.985339 | 3 | NaN | 523722.0 | 0.0 | NaN | 0.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 1.0 | NaN | NaN | Fine | Null |
| 3 | 173.897841 | -35.917064 | 4 | NaN | 504502.0 | 1.0 | NaN | 0.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Open | 0.0 | NaN | NaN | Fine | Null |
| 4 | 174.761113 | -36.863304 | 5 | NaN | 514301.0 | 0.0 | NaN | 1.0 | 1.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 0.0 | NaN | NaN | Fine | Null |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 733766 | 174.746739 | -36.889195 | 733767 | NaN | 518301.0 | 0.0 | NaN | 0.0 | 2.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 0.0 | NaN | NaN | Fine | Null |
| 733767 | 174.659241 | -36.831897 | 733768 | NaN | 512202.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | Urban | 0.0 | 0.0 | 0.0 | Fine | Null |
| 733768 | 174.845557 | -36.946666 | 733769 | NaN | 521802.0 | 0.0 | NaN | 0.0 | 2.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 0.0 | NaN | NaN | Fine | Null |
| 733769 | 172.832776 | -42.514413 | 733770 | NaN | 585502.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 1.0 | 0.0 | 0.0 | Urban | 0.0 | 0.0 | 0.0 | Mist or Fog | Null |
| 733770 | 174.850515 | -41.139181 | 733771 | NaN | 570900.0 | 0.0 | NaN | 0.0 | 2.0 | NaN | ... | NaN | NaN | 0.0 | 0.0 | Urban | 0.0 | NaN | NaN | Fine | Null |

733771 rows × 72 columns
The dataset contains 72 columns, describing various aspects of the recorded car crashes. The full description of the fields is available online: see https://opendata-nzta.opendata.arcgis.com/pages/cas-data-field-descriptions.
dset.columns
Index(['X', 'Y', 'OBJECTID', 'advisorySpeed', 'areaUnitID', 'bicycle', 'bridge', 'bus', 'carStationWagon', 'cliffBank', 'crashDirectionDescription', 'crashFinancialYear', 'crashLocation1', 'crashLocation2', 'crashRoadSideRoad', 'crashSeverity', 'crashSHDescription', 'crashYear', 'debris', 'directionRoleDescription', 'ditch', 'fatalCount', 'fence', 'flatHill', 'guardRail', 'holiday', 'houseOrBuilding', 'intersection', 'kerb', 'light', 'meshblockId', 'minorInjuryCount', 'moped', 'motorcycle', 'NumberOfLanes', 'objectThrownOrDropped', 'otherObject', 'otherVehicleType', 'overBank', 'parkedVehicle', 'pedestrian', 'phoneBoxEtc', 'postOrPole', 'region', 'roadCharacter', 'roadLane', 'roadSurface', 'roadworks', 'schoolBus', 'seriousInjuryCount', 'slipOrFlood', 'speedLimit', 'strayAnimal', 'streetLight', 'suv', 'taxi', 'temporarySpeedLimit', 'tlaId', 'tlaName', 'trafficControl', 'trafficIsland', 'trafficSign', 'train', 'tree', 'truck', 'unknownVehicleType', 'urban', 'vanOrUtility', 'vehicle', 'waterRiver', 'weatherA', 'weatherB'], dtype='object')
Note that `X` and `Y` are geographical coordinates in the WGS84 coordinate system (see EPSG:4326).
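As a quick sanity check (an added step, not part of the original analysis), the coordinate ranges should roughly match New Zealand's extent, with the Chatham Islands showing up as negative longitudes just east of the antimeridian.
# summary statistics of the crash coordinates
dset[["X", "Y"]].describe()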
First, we will look at the locations of the crashes. More accidents happen in denser areas, so it would be good to compare with population density.

Note: We removed the Chatham Islands data here to ease plotting.
def plot_hexmap(dset, ax=None):
    """Plot crash locations as a log-scaled hexbin map over a basemap."""
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))
    hb = ax.hexbin(
        dset["X"], dset["Y"], gridsize=500, cmap="BuPu", mincnt=1, bins="log"
    )
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    cx.add_basemap(ax, crs=4326, source=cx.providers.CartoDB.Positron)
    # put the colorbar on its own axis so it doesn't distort the map
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    ax.figure.colorbar(hb, cax=cax)
    return ax
plot_hexmap(dset[dset.X > 0])
In dense areas, like Auckland, there are enough crash events to map out the local road network.
dset_auckland = dset[dset["X"].between(174.7, 174.9) & dset["Y"].between(-37, -36.8)]
plot_hexmap(dset_auckland)
At a coarser level, there is also the region information.
region_perc = dset["region"].value_counts(normalize=True)
ax = region_perc.plot.bar(ylabel="fraction of crashes", figsize=(10, 5))
_ = ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
print(
    f"The top 4 regions account for {region_perc.nlargest(4).sum() * 100:0.1f}% "
    "of the crashes."
)
The top 4 regions account for 65.9% of the crashes.
The dataset contains only a few temporal features:

- `crashYear` and `crashFinancialYear`, respectively the calendar year and financial year of each crash,
- `holiday`, whether the crash occurred during a holiday period.

So we won't be able to study daily or weekly patterns with these data.
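A quick check (added here for convenience) confirms the limited temporal resolution.
# the finest time resolution is the calendar year of each crash
print("years covered:", dset["crashYear"].min(), "to", dset["crashYear"].max())
print("holiday periods:", dset["holiday"].dropna().unique())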
If we look at the yearly counts, we can see some fluctuations, mostly driven by the Auckland region but still noticeable in other parts of the country. The count for 2020 is much lower because it is the current, still incomplete, year.
year_counts = dset["crashYear"].value_counts(sort=False)
_ = year_counts.plot.bar(ylabel="# crashes", figsize=(10, 5))
year_region_counts = (
    dset.groupby(["crashYear", "region"]).size().reset_index(name="# crashes")
)
_, ax = plt.subplots(figsize=(10, 5))
sb.pointplot(data=year_region_counts, x="crashYear", y="# crashes", hue="region", ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
_ = plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
We can also explore spatio-temporal patterns. Here we focus on Auckland (excluding 2020).
grid = sb.FacetGrid(
    dset_auckland[dset_auckland.crashYear < 2020], col="crashYear", col_wrap=5
)
grid.map(plt.hexbin, "X", "Y", gridsize=500, cmap="BuPu", mincnt=1, bins="log")
The other temporal attribute is the holiday. Christmas is the holiday period with the most accidents. How each period's extent is computed is not clear, so the larger number of accidents could be partly due to a longer time extent: Easter, Queen's Birthday and Labour Weekend are 3-to-4-day periods, while Christmas & New Year probably spans 1 to 2 weeks.
holiday_counts = dset["holiday"].fillna("Normal day").value_counts()
ax = holiday_counts.plot.bar(ylabel="# crashes", rot=0, figsize=(10, 5))
_ = ax.set(yscale="log")
From the dataset field descriptions, the following features seem specific to the type of road:

- `crashSHDescription`, whether the crash happened on a state highway,
- `flatHill`, whether the road is flat or sloped,
- `junctionType`, type of junction the crash happened at (may also be unknown, and crashes not occurring at a junction are also unknown),
- `NumberOfLanes`, number of lanes on the crash road,
- `roadCharacter`, general nature of the road,
- `roadCurvature`, simplified curvature of the road,
- `roadLane`, lane configuration of the road (' ' for unknown or invalid configurations),
- `roadMarkings`, road markings at the crash site,
- `roadSurface`, road surface description applying at the crash site,
- `speedLimit`, speed limit in force at the crash site at the time of the crash (a number, or 'LSZ' for a limited speed zone),
- `streetLight`, street lighting at the time of the crash (this is also a sort of temporal information),
- `urban`, whether the road is in an urban area (derived from the speed limit).

Unfortunately, not all of these fields are actually available in the dataset.
road_features = set(
    [
        "crashSHDescription",
        "flatHill",
        "junctionType",
        "NumberOfLanes",
        "roadCharacter",
        "roadCurvature",
        "roadLane",
        "roadMarkings",
        "roadSurface",
        "speedLimit",
        "streetLight",
        "urban",
    ]
)
missing_features = road_features - set(dset.columns)
road_features -= missing_features
print("The following features are not found in the dataset:", missing_features)
The following features are not found in the dataset: {'roadMarkings', 'junctionType', 'roadCurvature'}
fig, axes = plt.subplots(3, 3, figsize=(15, 12))
for ax, feat in zip(axes.flat, sorted(road_features)):
    counts = dset[feat].value_counts(dropna=False)
    counts.plot.bar(ylabel="# crashes", title=feat, ax=ax)
    ax.set(yscale="log")
fig.tight_layout()
The `urban` feature is derived from `speedLimit`, so we can probably remove it.
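To double-check that claim, a simple cross-tabulation (an added check; the exact derivation rule is not documented here) shows how `urban` relates to `speedLimit`.
# urban roads are expected to correspond to the lower speed limits
pd.crosstab(dset["urban"], dset["speedLimit"])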
The environmental features are weather and sunshine:

- `light`, light at the time and place of the crash (this is also a sort of temporal information),
- `weatherA` and `weatherB`, weather at the crash time/place.

env_features = ["light", "weatherA", "weatherB"]
fig, axes = plt.subplots(1, 3, figsize=(13, 4))
for ax, feat in zip(axes.flat, env_features):
    counts = dset[feat].value_counts(dropna=False)
    counts.plot.bar(ylabel="# crashes", title=feat, ax=ax)
    ax.set(yscale="log")
fig.tight_layout()
We have examined the spatial, temporal, road and environmental features related to the accidents.

While these features tell us under which conditions accidents are relatively more frequent, we will need additional baseline information if we want to create a predictive model.
For the road features, we could use a LINZ dataset or another NZTA dataset that provides more information about road type and traffic. But then we would need to attribute each crash to a road.
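A minimal sketch of that matching step, assuming a roads layer is available (the `roads.geojson` path is hypothetical) and using geopandas, which is not among this notebook's dependencies:

import geopandas as gpd

# hypothetical roads layer, e.g. exported from LINZ or NZTA
roads = gpd.read_file("../data/roads.geojson").to_crs(2193)  # NZTM projection

# crash locations as points, reprojected so distances are in metres
crashes = gpd.GeoDataFrame(
    dset, geometry=gpd.points_from_xy(dset["X"], dset["Y"]), crs=4326
).to_crs(2193)

# attribute each crash to the nearest road (requires geopandas >= 0.10)
crashes_with_roads = gpd.sjoin_nearest(crashes, roads, how="left")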
Another option would be to regrid the data and, for each cell containing at least one crash event, associate the road features from those crash events. With this option, we make no prediction for grid cells where we have no information.
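Here is a minimal sketch of the regridding idea; the 0.01° cell size is an arbitrary choice, and the most common speed limit per cell serves as an example of an aggregated road feature.

# snap each crash to a regular lat/lon grid
cells = dset[["X", "Y", "speedLimit"]].copy()
cell_size = 0.01  # degrees, an arbitrary choice
cells["cell_x"] = (cells["X"] // cell_size) * cell_size
cells["cell_y"] = (cells["Y"] // cell_size) * cell_size

# aggregate a road feature per cell, e.g. the most common speed limit
cell_speed = (
    cells.groupby(["cell_x", "cell_y"])["speedLimit"]
    .agg(lambda s: s.mode().iloc[0] if not s.mode().empty else pd.NA)
    .reset_index()
)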
For the environmental information, we would need weather records covering all days of the year.
The prediction task can be formulated in different ways:

- exclude weather & holiday features, and fit a regression model on count data using year & location features (accounting for traffic volume, to compare the number of crashes per car on the road),
- group data by location, time and weather type (e.g. rain vs. no rain), and perform a binomial regression using the total number of days in each category (e.g. the number of rainy days for a particular location & year),
- predict crash severity from the whole dataset, assuming the non-severe crashes are a good proxy for normal conditions (weather, holidays, etc.).
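As an illustration of the third formulation, here is a minimal, untuned sketch predicting `crashSeverity` from the road features with LightGBM (already in the environment); the encoding and evaluation choices are ours, not prescribed by the dataset.

from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split

# one-hot encode the categorical road features (NaNs get their own column)
features = pd.get_dummies(dset[sorted(road_features)], dummy_na=True)
# sanitize dummy column names, as LightGBM is picky about special characters
features.columns = features.columns.str.replace(r"[^\w]", "_", regex=True)
target = dset["crashSeverity"]

X_train, X_test, y_train, y_test = train_test_split(
    features, target, test_size=0.2, random_state=42, stratify=target
)

model = LGBMClassifier().fit(X_train, y_train)
print(f"held-out accuracy: {model.score(X_test, y_test):.3f}")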
!date -R
Thu, 10 Dec 2020 19:25:49 +0000
!uname -a
Linux wbl009 3.10.0-693.2.2.el7.x86_64 #1 SMP Tue Sep 12 22:26:13 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux
!pip freeze
affine==2.3.0 amply==0.1.4 appdirs==1.4.4 async-generator==1.10 attrs==20.3.0 backcall==0.2.0 black==20.8b1 bleach==3.2.1 bokeh==2.2.3 Cartopy @ file:///home/conda/feedstock_root/build_artifacts/cartopy_1604218104490/work certifi==2020.12.5 chardet==3.0.4 click==7.1.2 click-plugins==1.1.1 cligj==0.7.1 cloudpickle==1.6.0 ConfigArgParse==1.2.3 contextily==1.0.1 -e git+git@github.com:neon-ninja/crash_prediction.git@fc9214da20747e5e6cd459b122b0379395b3567c#egg=crash_prediction cycler==0.10.0 dask==2.30.0 dask-glm==0.2.0 dask-jobqueue==0.7.2 dask-ml==1.7.0 datrie==0.8.2 decorator==4.4.2 defopt==6.0.2 defusedxml==0.6.0 distributed==2.30.1 docutils==0.16 entrypoints==0.3 flake8==3.8.4 fsspec==0.8.4 geographiclib==1.50 geopy==2.0.0 gitdb==4.0.5 GitPython==3.1.11 HeapDict==1.0.1 idna==2.10 ipykernel==5.3.4 ipython==7.19.0 ipython-genutils==0.2.0 jedi==0.17.2 Jinja2==2.11.2 joblib==0.17.0 jsonschema==3.2.0 jupyter-client==6.1.7 jupyter-core==4.7.0 jupyterlab-pygments==0.1.2 jupytext==1.7.1 kiwisolver==1.3.1 lightgbm==3.1.1 llvmlite==0.35.0 locket==0.2.0 markdown-it-py==0.5.6 MarkupSafe==1.1.1 matplotlib==3.3.3 mccabe==0.6.1 mercantile==1.1.6 mistune==0.8.4 msgpack==1.0.0 multipledispatch==0.6.0 mypy-extensions==0.4.3 nbclient==0.5.1 nbconvert==6.0.7 nbformat==5.0.8 nest-asyncio==1.4.3 numba==0.52.0 numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1604945996350/work packaging==20.7 pandas==1.1.5 pandocfilters==1.4.3 parso==0.7.1 partd==1.1.0 pathspec==0.8.1 pexpect==4.8.0 pickleshare==0.7.5 Pillow==8.0.1 pockets==0.9.1 prompt-toolkit==3.0.8 psutil==5.7.3 ptyprocess==0.6.0 PuLP==2.3.1 pycodestyle==2.6.0 pyflakes==2.2.0 Pygments==2.7.3 pyparsing==2.4.7 pyrsistent==0.17.3 pyshp @ file:///home/conda/feedstock_root/build_artifacts/pyshp_1599782465740/work python-dateutil==2.8.1 pytz==2020.4 PyYAML==5.3.1 pyzmq==20.0.0 rasterio==1.1.8 ratelimiter==1.2.0.post0 regex==2020.11.13 requests==2.25.0 scikit-learn==0.23.2 scipy==1.5.4 seaborn==0.11.0 Shapely @ file:///home/conda/feedstock_root/build_artifacts/shapely_1602547954120/work six @ file:///home/conda/feedstock_root/build_artifacts/six_1590081179328/work smmap==3.0.4 snakemake==5.30.1 snuggs==1.4.7 sortedcontainers==2.3.0 sphinxcontrib-napoleon==0.7 tblib==1.7.0 testpath==0.4.4 threadpoolctl==2.1.0 toml==0.10.2 toolz==0.11.1 toposort==1.5 tornado==6.1 traitlets==5.0.5 typed-ast==1.4.1 typing-extensions==3.7.4.3 urllib3==1.26.2 wcwidth==0.2.5 webencodings==0.5.1 wrapt==1.12.1 zict==2.0.0