Chapter 6. Asteroid spectral classification
Additional material to chapter 6: Asteroid spectro-photometric characterisation, in: Valerio Carruba, Evgeny Smirnov, Dagmara Anna Oszkiewicz (eds.), Machine Learning for Small Bodies in the Solar System, Elsevier 2024.
Authors of chapter 6 and the following material:
Dagmara Oszkiewicz1, Antti Penttilä2, and Hanna Klimczak-Plucinska1
1Astronomical Observatory Institute, Faculty of Physics, Adam Mickiewicz University, Słoneczna 36, 60-286 Poznań, Poland
2Department of Physics, University of Helsinki, Finland
03-04-2024
We need the following Python packages.
# For data simulation with PCA
import pandas as pd
import numpy as np
import scipy.linalg
%matplotlib inline
from matplotlib import pyplot as plt
from prettytable import PrettyTable
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import warnings
warnings.filterwarnings("ignore")
# For classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from collections import Counter
We will use our own collection of asteroid spectral data and taxonomic labels, based on the SMASSII and MITHNEOS data sets; we have joined the two and simplified the taxonomic labels. First, let's create the wavelengths of the data. The wavelength 550 nm is not included, since the spectra are normalized at that point, so its value would always be 1.
temp1 = np.arange(450,550,10)
temp2 = np.arange(560,2455,10)
wavelengths = np.concatenate([temp1,temp2])
len(wavelengths)
200
Then, read the data in as a Pandas dataframe. The wavelengths are used as column labels, together with the label 'tax' for the taxonomy.
astdata = pd.read_csv('asteroid-spectral-data.csv',header=None,names=np.concatenate([['tax'],wavelengths]))
astdata
tax | 450 | 460 | 470 | 480 | 490 | 500 | 510 | 520 | 530 | ... | 2360 | 2370 | 2380 | 2390 | 2400 | 2410 | 2420 | 2430 | 2440 | 2450 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | C | 0.936682 | 0.948220 | 0.958273 | 0.966964 | 0.974415 | 0.980751 | 0.986094 | 0.990568 | 0.994294 | ... | 1.002864 | 1.003190 | 1.003550 | 1.003946 | 1.004382 | 1.004860 | 1.005383 | 1.005955 | 1.006578 | 1.007254 |
1 | B | 0.972303 | 0.977751 | 0.982461 | 0.986487 | 0.989882 | 0.992699 | 0.994993 | 0.996816 | 0.998223 | ... | 0.836224 | 0.835711 | 0.835245 | 0.834830 | 0.834470 | 0.834168 | 0.833927 | 0.833752 | 0.833646 | 0.833612 |
2 | S | 0.882684 | 0.895862 | 0.910649 | 0.924991 | 0.937650 | 0.948964 | 0.959456 | 0.969644 | 0.979917 | ... | 1.256736 | 1.257612 | 1.258852 | 1.260557 | 1.262824 | 1.265675 | 1.268821 | 1.271896 | 1.274532 | 1.276362 |
3 | V | 0.901729 | 0.912059 | 0.924331 | 0.937435 | 0.950260 | 0.961788 | 0.971787 | 0.980430 | 0.987896 | ... | 1.118528 | 1.125348 | 1.132181 | 1.138816 | 1.145046 | 1.150661 | 1.155450 | 1.159206 | 1.161719 | 1.162779 |
4 | S | 0.857128 | 0.876266 | 0.892801 | 0.907897 | 0.922509 | 0.936765 | 0.950582 | 0.963880 | 0.976577 | ... | 1.287134 | 1.288773 | 1.290575 | 1.292559 | 1.294742 | 1.297205 | 1.300269 | 1.304318 | 1.309735 | 1.316905 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
581 | L | 0.964614 | 0.956561 | 0.951836 | 0.950159 | 0.951247 | 0.954816 | 0.960585 | 0.968270 | 0.977589 | ... | 1.707470 | 1.715092 | 1.723149 | 1.731653 | 1.740618 | 1.750054 | 1.759973 | 1.770388 | 1.781310 | 1.792751 |
582 | S | 0.778583 | 0.818356 | 0.852047 | 0.880469 | 0.904433 | 0.924750 | 0.942231 | 0.957689 | 0.971935 | ... | 1.384155 | 1.391469 | 1.398232 | 1.404404 | 1.409946 | 1.414820 | 1.418988 | 1.422410 | 1.425049 | 1.426865 |
583 | X | 1.041278 | 1.028260 | 1.018052 | 1.010335 | 1.004792 | 1.001103 | 0.998951 | 0.998016 | 0.998013 | ... | 1.226848 | 1.229681 | 1.232610 | 1.235639 | 1.238770 | 1.242007 | 1.245353 | 1.248810 | 1.252381 | 1.256070 |
584 | Q | 0.853625 | 0.864378 | 0.876155 | 0.889071 | 0.903245 | 0.918794 | 0.935665 | 0.953124 | 0.970268 | ... | 0.904400 | 0.903679 | 0.903112 | 0.902801 | 0.902851 | 0.903363 | 0.904442 | 0.906189 | 0.908709 | 0.912104 |
585 | L | 0.912581 | 0.924933 | 0.936311 | 0.946782 | 0.956412 | 0.965266 | 0.973412 | 0.980913 | 0.987822 | ... | 1.357655 | 1.360819 | 1.364037 | 1.367312 | 1.370643 | 1.374032 | 1.377481 | 1.380989 | 1.384559 | 1.388191 |
586 rows × 201 columns
Let's see what the spectra look like for the first five asteroids.
temp = astdata.iloc[0:5,1:].to_numpy()
plt.plot(wavelengths, temp.transpose());
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();
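As a quick optional sanity check (a snippet added here, not part of the original material), we can interpolate the first spectrum at 550 nm; the value should be close to 1, since the spectra are normalized at that wavelength.
# Interpolate the first spectrum at 550 nm, which is not on the wavelength
# grid; the spectra are normalized there, so the result should be close to 1.
first = astdata.iloc[0, 1:].to_numpy().astype(float)
print(np.interp(550.0, wavelengths, first))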
Save the Pandas dataframe to an HDF5 file for faster access later.
astdata.to_hdf('astdata.h5', key='astdata', mode='w')
Simulate more spectra from existing samples using PCA
We will show how to create simulated samples within each spectral taxonomic group using the chain data -> PCA representation -> data.
First, we load the asteroid spectral data as a Pandas dataframe.
astdata = pd.read_hdf('astdata.h5')
astdata
tax | 450 | 460 | 470 | 480 | 490 | 500 | 510 | 520 | 530 | ... | 2360 | 2370 | 2380 | 2390 | 2400 | 2410 | 2420 | 2430 | 2440 | 2450 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | C | 0.936682 | 0.948220 | 0.958273 | 0.966964 | 0.974415 | 0.980751 | 0.986094 | 0.990568 | 0.994294 | ... | 1.002864 | 1.003190 | 1.003550 | 1.003946 | 1.004382 | 1.004860 | 1.005383 | 1.005955 | 1.006578 | 1.007254 |
1 | B | 0.972303 | 0.977751 | 0.982461 | 0.986487 | 0.989882 | 0.992699 | 0.994993 | 0.996816 | 0.998223 | ... | 0.836224 | 0.835711 | 0.835245 | 0.834830 | 0.834470 | 0.834168 | 0.833927 | 0.833752 | 0.833646 | 0.833612 |
2 | S | 0.882684 | 0.895862 | 0.910649 | 0.924991 | 0.937650 | 0.948964 | 0.959456 | 0.969644 | 0.979917 | ... | 1.256736 | 1.257612 | 1.258852 | 1.260557 | 1.262824 | 1.265675 | 1.268821 | 1.271896 | 1.274532 | 1.276362 |
3 | V | 0.901729 | 0.912059 | 0.924331 | 0.937435 | 0.950260 | 0.961788 | 0.971787 | 0.980430 | 0.987896 | ... | 1.118528 | 1.125348 | 1.132181 | 1.138816 | 1.145046 | 1.150661 | 1.155450 | 1.159206 | 1.161719 | 1.162779 |
4 | S | 0.857128 | 0.876266 | 0.892801 | 0.907897 | 0.922509 | 0.936765 | 0.950582 | 0.963880 | 0.976577 | ... | 1.287134 | 1.288773 | 1.290575 | 1.292559 | 1.294742 | 1.297205 | 1.300269 | 1.304318 | 1.309735 | 1.316905 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
581 | L | 0.964614 | 0.956561 | 0.951836 | 0.950159 | 0.951247 | 0.954816 | 0.960585 | 0.968270 | 0.977589 | ... | 1.707470 | 1.715092 | 1.723149 | 1.731653 | 1.740618 | 1.750054 | 1.759973 | 1.770388 | 1.781310 | 1.792751 |
582 | S | 0.778583 | 0.818356 | 0.852047 | 0.880469 | 0.904433 | 0.924750 | 0.942231 | 0.957689 | 0.971935 | ... | 1.384155 | 1.391469 | 1.398232 | 1.404404 | 1.409946 | 1.414820 | 1.418988 | 1.422410 | 1.425049 | 1.426865 |
583 | X | 1.041278 | 1.028260 | 1.018052 | 1.010335 | 1.004792 | 1.001103 | 0.998951 | 0.998016 | 0.998013 | ... | 1.226848 | 1.229681 | 1.232610 | 1.235639 | 1.238770 | 1.242007 | 1.245353 | 1.248810 | 1.252381 | 1.256070 |
584 | Q | 0.853625 | 0.864378 | 0.876155 | 0.889071 | 0.903245 | 0.918794 | 0.935665 | 0.953124 | 0.970268 | ... | 0.904400 | 0.903679 | 0.903112 | 0.902801 | 0.902851 | 0.903363 | 0.904442 | 0.906189 | 0.908709 | 0.912104 |
585 | L | 0.912581 | 0.924933 | 0.936311 | 0.946782 | 0.956412 | 0.965266 | 0.973412 | 0.980913 | 0.987822 | ... | 1.357655 | 1.360819 | 1.364037 | 1.367312 | 1.370643 | 1.374032 | 1.377481 | 1.380989 | 1.384559 | 1.388191 |
586 rows × 201 columns
Get the wavelengths from the column labels.
wavelengths = astdata.columns.to_numpy()[1:].astype('float')
len(wavelengths)
200
Group data by taxonomic label.
astgrdata = astdata.groupby('tax')
How many groups are there, what are they, and how many asteroids are in each group?
astgrdata.ngroups
astgrdata.size()
11

tax
A      7
B     12
C     61
D     22
K     15
L     33
Q     43
S    310
T      4
V     28
X     51
dtype: int64
Data -> PCA
We will do the PCA decomposition 'by hand' for each taxonomic group, since we will need to apply the inverse transform at the end.
meanvecs = []
covmats = []
for tax, gr in astgrdata:
    print("Group "+tax)
    # Numerical data per group to numpy array
    tempd = gr.iloc[:,1:].to_numpy()
    print("Data shape: ", np.shape(tempd))
    # Mean vector
    mv = np.mean(tempd,axis=0)
    meanvecs.append(mv)
    # Center the data
    tempdc = tempd-mv
    # Covariance matrix
    cm = np.cov(tempdc,rowvar=False)
    covmats.append(cm)
Group A
Data shape: (7, 200)
Group B
Data shape: (12, 200)
Group C
Data shape: (61, 200)
Group D
Data shape: (22, 200)
Group K
Data shape: (15, 200)
Group L
Data shape: (33, 200)
Group Q
Data shape: (43, 200)
Group S
Data shape: (310, 200)
Group T
Data shape: (4, 200)
Group V
Data shape: (28, 200)
Group X
Data shape: (51, 200)
Eigenvalue decompositions of the covariance matrices. The matrices are symmetric, so the eigenvalues and eigenvectors are real; we drop the negligible imaginary parts returned by the general eig routine.
evalvecs = []
evecmats = []
for cm in covmats:
    evals, evecs = scipy.linalg.eig(cm)
    evalvecs.append(np.real(evals))
    evecmats.append(np.real(evecs))
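As an optional cross-check (a sketch using scikit-learn, not part of the original material), the hand-computed eigenvalues should match the explained variances from sklearn's PCA fitted on the same group; here we compare the five largest for the S group.
# Optional cross-check against scikit-learn's PCA for the S group.
from sklearn.decomposition import PCA
idx = list(astgrdata.groups.keys()).index('S')
tempd = astgrdata.get_group('S').iloc[:,1:].to_numpy()
skpca = PCA().fit(tempd)
# scipy.linalg.eig does not guarantee ordering, so sort in descending order.
ours = np.sort(evalvecs[idx])[::-1]
np.allclose(ours[:5], skpca.explained_variance_[:5])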
Table showing the cumulative proportion of variance explained by the first seven PCA components for each taxonomic group. Note that the smallest groups have rank-deficient covariance matrices (e.g., T with only 4 spectra has rank at most 3), so their cumulative variance reaches 1 after only a few components.
tab = PrettyTable(["Taxonomy", "PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7"])
taxlabs = astgrdata.groups.keys()
for tax, ev in zip(taxlabs, evalvecs):
    # Cumulative proportion of variance, rounded for readability
    t = [tax] + [f"{v:.4f}" for v in np.cumsum(ev[:7])/np.sum(ev)]
    tab.add_row(t)
tab
Taxonomy | PCA1 | PCA2 | PCA3 | PCA4 | PCA5 | PCA6 | PCA7 |
---|---|---|---|---|---|---|---|
A | 0.9078 | 0.9740 | 0.9942 | 0.9980 | 0.9998 | 1.0000 | 1.0000 |
B | 0.8342 | 0.9388 | 0.9698 | 0.9878 | 0.9936 | 0.9968 | 0.9987 |
C | 0.8669 | 0.9632 | 0.9820 | 0.9892 | 0.9939 | 0.9966 | 0.9982 |
D | 0.9674 | 0.9869 | 0.9954 | 0.9976 | 0.9985 | 0.9992 | 0.9996 |
K | 0.8899 | 0.9665 | 0.9871 | 0.9938 | 0.9977 | 0.9986 | 0.9991 |
L | 0.8918 | 0.9600 | 0.9786 | 0.9904 | 0.9951 | 0.9975 | 0.9983 |
Q | 0.8512 | 0.9237 | 0.9474 | 0.9640 | 0.9766 | 0.9846 | 0.9889 |
S | 0.9031 | 0.9626 | 0.9737 | 0.9827 | 0.9896 | 0.9932 | 0.9950 |
T | 0.9379 | 0.9778 | 1.0000 | 1.0000 | 1.0000 | 1.0000 | 1.0000 |
V | 0.8910 | 0.9452 | 0.9682 | 0.9847 | 0.9934 | 0.9962 | 0.9972 |
X | 0.9190 | 0.9829 | 0.9900 | 0.9944 | 0.9964 | 0.9979 | 0.9985 |
Random sample based on the PCA representation of the real spectra
An example of how to simulate a spectral sample based on the PCA representation of the real samples.
Let's take one S-type spectrum as an example.
one = astgrdata.get_group('S').iloc[0,1:].to_numpy()
plt.plot(wavelengths,one);
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();
Convert to the PCA representation (S is the label at index 7 in the grouped data).
onepca = (one-meanvecs[7]) @ evecmats[7]
Simulate random noise for the components according to their eigenvalues. A scaling factor of 0.6 is applied to the component standard deviations (square roots of the eigenvalues), and the eigenvalues are clipped at zero to guard against small negative values from numerical round-off.
rn = np.random.normal(scale=0.6*np.sqrt(np.clip(evalvecs[7],0,None)))
Add noise and transform back from PCA space.
onepca1 = onepca + rn
one1 = onepca1 @ evecmats[7].T + meanvecs[7]
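Before looking at the noisy sample, it is worth confirming that the round trip data -> PCA -> data without added noise reproduces the original spectrum. A quick optional check (not in the original material); the result should be at the level of numerical round-off.
# Maximum absolute round-trip error of the noise-free PCA transform
np.max(np.abs(onepca @ evecmats[7].T + meanvecs[7] - one))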
Show the original spectrum and the random sample based on it.
plt.plot(wavelengths,one,wavelengths,one1);
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();
PCA -> data
Let's apply the above procedure to all taxonomic groups and simulate 2000 samples for each taxonomy.
k = 2000
noisescale = 0.6
astsimdata = []
taxlab = list(astgrdata.groups.keys())
for i in range(astgrdata.ngroups):
    print("Group "+taxlab[i])
    d1 = astgrdata.get_group(taxlab[i]).iloc[:,1:].to_numpy()
    print(str(len(d1))+" asteroids in group")
    # Base each simulated sample on a randomly chosen real spectrum.
    ranind = np.random.randint(0,high=len(d1),size=k)
    sample = np.zeros((k,len(d1[0])))
    for j in range(k):
        basespectra = d1[ranind[j]]
        # PCA representation
        pca = (basespectra-meanvecs[i]) @ evecmats[i]
        # Random change to the PCA components
        rannoise = np.random.normal(scale=noisescale*np.sqrt(np.clip(evalvecs[i],0,None)))
        # Back to spectra
        simspectra = (pca + rannoise) @ evecmats[i].T + meanvecs[i]
        sample[j] = simspectra
    astsimdata.append(sample)
Group A
7 asteroids in group
Group B
12 asteroids in group
Group C
61 asteroids in group
Group D
22 asteroids in group
Group K
15 asteroids in group
Group L
33 asteroids in group
Group Q
43 asteroids in group
Group S
310 asteroids in group
Group T
4 asteroids in group
Group V
28 asteroids in group
Group X
51 asteroids in group
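The double loop above is easy to read, but since every step is a matrix operation, the inner loop can also be vectorized. Below is an optional sketch of the same computation for a single group; it produces a statistically equivalent sample but is not the version used for the data in this material.
i = 0  # group index (taxonomy 'A'); the same recipe works for any group
d1 = astgrdata.get_group(taxlab[i]).iloc[:,1:].to_numpy()
# Draw all k base spectra at once instead of one per loop iteration
base = d1[np.random.randint(0, high=len(d1), size=k)]
pcs = (base - meanvecs[i]) @ evecmats[i]
# One (k x 200) noise matrix; the scale vector broadcasts over the rows
sigma = noisescale * np.sqrt(np.clip(evalvecs[i], 0, None))
noise = np.random.normal(scale=sigma, size=(k, len(sigma)))
sample_vec = (pcs + noise) @ evecmats[i].T + meanvecs[i]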
Plot the samples from each group.
for i in range(astgrdata.ngroups):
    _ = plt.plot(wavelengths,astsimdata[i].T)
    _ = plt.xlabel('wavelength')
    _ = plt.ylabel('reflectance')
    _ = plt.axhline(1, linestyle='--', color='k')
    _ = plt.title(taxlab[i])
    plt.show()
Save the simulated sample.
# Concatenate into one matrix.
onematrix = np.concatenate(astsimdata)
# Build the vector of labels.
labs = []
for tax in taxlab:
    labs.append(np.full((k,1),tax))
onelabs = np.concatenate(labs)
# Join labels and data.
outmatrix = np.c_[onelabs, onematrix]
np.shape(outmatrix)
(22000, 201)
Save the simulated samples as an HDF5 file.
astsimdf = pd.DataFrame(outmatrix,columns=np.concatenate([['tax'],wavelengths]))
astsimdf.to_hdf('astdata-simulated.h5', key='astsimdata', mode='w')
astsimdf
tax | 450.0 | 460.0 | 470.0 | 480.0 | 490.0 | 500.0 | 510.0 | 520.0 | 530.0 | ... | 2360.0 | 2370.0 | 2380.0 | 2390.0 | 2400.0 | 2410.0 | 2420.0 | 2430.0 | 2440.0 | 2450.0 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A | 0.705580511979575 | 0.7326578377282789 | 0.7612766064650269 | 0.790763773628808 | 0.8206035917883197 | 0.8507091598684661 | 0.8809020071529161 | 0.9110709048791528 | 0.9411099554544028 | ... | 2.605407249936357 | 2.6063895690546035 | 2.6070578522133414 | 2.6057334305850905 | 2.605000334669502 | 2.6042814767891067 | 2.6038356737304618 | 2.603928861322881 | 2.6048270326597267 | 2.6067961589745665 |
1 | A | 0.7543212749301422 | 0.7710331551024008 | 0.794609162338694 | 0.8209888707426471 | 0.8468705491698457 | 0.8720104059298207 | 0.8969056664336702 | 0.9220026220704322 | 0.9477351002160105 | ... | 2.200904117908356 | 2.205464986153305 | 2.2115637743242025 | 2.2118412622746746 | 2.215309051608229 | 2.2187902569306326 | 2.2221600828141983 | 2.225281601120992 | 2.2280178677694655 | 2.230231876101838 |
2 | A | 0.7637397260963806 | 0.7780492534053065 | 0.8002521804826961 | 0.8258716515810832 | 0.8511833980523711 | 0.8752433216064436 | 0.8983694405569923 | 0.9217725174326638 | 0.94665728036047 | ... | 2.0732455857656777 | 2.0751229179820574 | 2.0775157541381484 | 2.078554940400557 | 2.0809079774767447 | 2.083817880163292 | 2.087189007217301 | 2.0908761768045387 | 2.094734212847226 | 2.098617917080656 |
3 | A | 0.7738975011698456 | 0.7885199044747664 | 0.8104609751854193 | 0.8351439444306219 | 0.8588025014864068 | 0.8807442617571755 | 0.9017277150101468 | 0.9233332089256702 | 0.9471102273205482 | ... | 1.937158735691603 | 1.9393691914491984 | 1.9420789864601704 | 1.9457076027600417 | 1.950134131738519 | 1.9555722654659562 | 1.9618640516347374 | 1.9687842272095297 | 1.9761075265309989 | 1.9836086802985917 |
4 | A | 0.7563375306899326 | 0.7770924206584446 | 0.8001907701892301 | 0.823833745176814 | 0.8467578142153621 | 0.8695916932080284 | 0.8930824527035657 | 0.917847724924682 | 0.9444491125936656 | ... | 2.0743223212649915 | 2.0763653677692653 | 2.0785748724238395 | 2.082123774871022 | 2.085646598284524 | 2.08960955789617 | 2.093777868694627 | 2.0978615656955437 | 2.1015707158349075 | 2.1046153960238327 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
21995 | X | 0.9352696125079532 | 0.9428596301600836 | 0.9540222528029001 | 0.9664716766270952 | 0.9780333038463478 | 0.9867203231201302 | 0.9924311608152832 | 0.9957340067082262 | 0.9974785762425892 | ... | 1.721042477022407 | 1.722614402092429 | 1.7240860694899345 | 1.725510096014843 | 1.7269391593810006 | 1.7284056969204014 | 1.7299219023392576 | 1.7314991074464627 | 1.7331486375878389 | 1.7348774878656745 |
21996 | X | 1.01839230087239 | 1.0082193893104505 | 1.0035992201952535 | 1.0026919824190126 | 1.0031222623804033 | 1.0046887231165573 | 1.0053860240667778 | 1.004993555143038 | 1.003624977355195 | ... | 1.4566814198609133 | 1.4579935704996116 | 1.4591465475145518 | 1.460140470839626 | 1.4609745620089734 | 1.461639346814181 | 1.462116680079996 | 1.462388023544818 | 1.4624348548964001 | 1.4622409902706912 |
21997 | X | 0.9485436316078056 | 0.9529741538793863 | 0.9576646382182891 | 0.9626506818592507 | 0.967793809856022 | 0.9736716213173331 | 0.9795531171084139 | 0.9851950979312695 | 0.9905262438886581 | ... | 1.5540567458541845 | 1.5558885833471405 | 1.5573905910749701 | 1.5585295371378336 | 1.5592726465283473 | 1.5595962137855905 | 1.5594856199893876 | 1.5589266396795327 | 1.557905036446503 | 1.5564050708327737 |
21998 | X | 0.8901368475123629 | 0.9167001040797195 | 0.9385639144201832 | 0.9555572027945662 | 0.9675228722381615 | 0.9745375588184017 | 0.9787498403896902 | 0.9825660630780803 | 0.9869132002082837 | ... | 1.4148131758852065 | 1.4161437075661685 | 1.4174508624879245 | 1.418696303056643 | 1.4198412027358978 | 1.4208416611712187 | 1.4216486847396312 | 1.4222130626837708 | 1.422485588352071 | 1.4224156954934206 |
21999 | X | 0.9463348150783925 | 0.9441194489542596 | 0.9414787352728049 | 0.9397497567217487 | 0.9405898310082094 | 0.9439223495558785 | 0.9510582164011503 | 0.9615533586896492 | 0.9740879621503112 | ... | 1.3950699387230134 | 1.3950505268487752 | 1.394778368545595 | 1.3942186633809017 | 1.39333757049001 | 1.3921276890617873 | 1.3906080291056508 | 1.3887987584209467 | 1.3867200351033229 | 1.3843922257656194 |
22000 rows × 201 columns
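Note that np.c_ above joins the string labels with the float reflectances, so NumPy coerces the whole matrix to strings; this is why the values in the table display at full string precision. If numeric columns are preferred, the dataframe can instead be built directly from the float matrix, for example with this minimal sketch:
# Alternative construction (sketch): keep the reflectances numeric by building
# the dataframe from the float matrix and inserting the label column separately.
astsimdf_num = pd.DataFrame(onematrix, columns=wavelengths)
astsimdf_num.insert(0, 'tax', onelabs.ravel())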
Automatic classification with logistic regression
Next, we will use our simulated data to train a machine learning model to assign taxonomic classes.
Read in the simulated asteroid spectral data.
astsimdf = pd.read_hdf('astdata-simulated.h5')
astsimdf
tax | 450.0 | 460.0 | 470.0 | 480.0 | 490.0 | 500.0 | 510.0 | 520.0 | 530.0 | ... | 2360.0 | 2370.0 | 2380.0 | 2390.0 | 2400.0 | 2410.0 | 2420.0 | 2430.0 | 2440.0 | 2450.0 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A | 0.705580511979575 | 0.7326578377282789 | 0.7612766064650269 | 0.790763773628808 | 0.8206035917883197 | 0.8507091598684661 | 0.8809020071529161 | 0.9110709048791528 | 0.9411099554544028 | ... | 2.605407249936357 | 2.6063895690546035 | 2.6070578522133414 | 2.6057334305850905 | 2.605000334669502 | 2.6042814767891067 | 2.6038356737304618 | 2.603928861322881 | 2.6048270326597267 | 2.6067961589745665 |
1 | A | 0.7543212749301422 | 0.7710331551024008 | 0.794609162338694 | 0.8209888707426471 | 0.8468705491698457 | 0.8720104059298207 | 0.8969056664336702 | 0.9220026220704322 | 0.9477351002160105 | ... | 2.200904117908356 | 2.205464986153305 | 2.2115637743242025 | 2.2118412622746746 | 2.215309051608229 | 2.2187902569306326 | 2.2221600828141983 | 2.225281601120992 | 2.2280178677694655 | 2.230231876101838 |
2 | A | 0.7637397260963806 | 0.7780492534053065 | 0.8002521804826961 | 0.8258716515810832 | 0.8511833980523711 | 0.8752433216064436 | 0.8983694405569923 | 0.9217725174326638 | 0.94665728036047 | ... | 2.0732455857656777 | 2.0751229179820574 | 2.0775157541381484 | 2.078554940400557 | 2.0809079774767447 | 2.083817880163292 | 2.087189007217301 | 2.0908761768045387 | 2.094734212847226 | 2.098617917080656 |
3 | A | 0.7738975011698456 | 0.7885199044747664 | 0.8104609751854193 | 0.8351439444306219 | 0.8588025014864068 | 0.8807442617571755 | 0.9017277150101468 | 0.9233332089256702 | 0.9471102273205482 | ... | 1.937158735691603 | 1.9393691914491984 | 1.9420789864601704 | 1.9457076027600417 | 1.950134131738519 | 1.9555722654659562 | 1.9618640516347374 | 1.9687842272095297 | 1.9761075265309989 | 1.9836086802985917 |
4 | A | 0.7563375306899326 | 0.7770924206584446 | 0.8001907701892301 | 0.823833745176814 | 0.8467578142153621 | 0.8695916932080284 | 0.8930824527035657 | 0.917847724924682 | 0.9444491125936656 | ... | 2.0743223212649915 | 2.0763653677692653 | 2.0785748724238395 | 2.082123774871022 | 2.085646598284524 | 2.08960955789617 | 2.093777868694627 | 2.0978615656955437 | 2.1015707158349075 | 2.1046153960238327 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
21995 | X | 0.9352696125079532 | 0.9428596301600836 | 0.9540222528029001 | 0.9664716766270952 | 0.9780333038463478 | 0.9867203231201302 | 0.9924311608152832 | 0.9957340067082262 | 0.9974785762425892 | ... | 1.721042477022407 | 1.722614402092429 | 1.7240860694899345 | 1.725510096014843 | 1.7269391593810006 | 1.7284056969204014 | 1.7299219023392576 | 1.7314991074464627 | 1.7331486375878389 | 1.7348774878656745 |
21996 | X | 1.01839230087239 | 1.0082193893104505 | 1.0035992201952535 | 1.0026919824190126 | 1.0031222623804033 | 1.0046887231165573 | 1.0053860240667778 | 1.004993555143038 | 1.003624977355195 | ... | 1.4566814198609133 | 1.4579935704996116 | 1.4591465475145518 | 1.460140470839626 | 1.4609745620089734 | 1.461639346814181 | 1.462116680079996 | 1.462388023544818 | 1.4624348548964001 | 1.4622409902706912 |
21997 | X | 0.9485436316078056 | 0.9529741538793863 | 0.9576646382182891 | 0.9626506818592507 | 0.967793809856022 | 0.9736716213173331 | 0.9795531171084139 | 0.9851950979312695 | 0.9905262438886581 | ... | 1.5540567458541845 | 1.5558885833471405 | 1.5573905910749701 | 1.5585295371378336 | 1.5592726465283473 | 1.5595962137855905 | 1.5594856199893876 | 1.5589266396795327 | 1.557905036446503 | 1.5564050708327737 |
21998 | X | 0.8901368475123629 | 0.9167001040797195 | 0.9385639144201832 | 0.9555572027945662 | 0.9675228722381615 | 0.9745375588184017 | 0.9787498403896902 | 0.9825660630780803 | 0.9869132002082837 | ... | 1.4148131758852065 | 1.4161437075661685 | 1.4174508624879245 | 1.418696303056643 | 1.4198412027358978 | 1.4208416611712187 | 1.4216486847396312 | 1.4222130626837708 | 1.422485588352071 | 1.4224156954934206 |
21999 | X | 0.9463348150783925 | 0.9441194489542596 | 0.9414787352728049 | 0.9397497567217487 | 0.9405898310082094 | 0.9439223495558785 | 0.9510582164011503 | 0.9615533586896492 | 0.9740879621503112 | ... | 1.3950699387230134 | 1.3950505268487752 | 1.394778368545595 | 1.3942186633809017 | 1.39333757049001 | 1.3921276890617873 | 1.3906080291056508 | 1.3887987584209467 | 1.3867200351033229 | 1.3843922257656194 |
22000 rows × 201 columns
Let's look at the class distribution in our simulated data.
n = len(astsimdf['tax'].unique().tolist())
p = sns.color_palette("husl", n)
sns.countplot(astsimdf, x="tax", hue='tax', palette=p);
As expected, there are 2000 samples for each taxonomy. This is important for our classification accuracy: it ensures all classes are equally represented during training and evaluation.
Let's select the input and output of our model.
X, y = astsimdf.drop(['tax'], axis=1), astsimdf['tax']
Next, we split our data into training and test sets. We will use 67% of the data for training and the remaining 33% for evaluation. The split function from sklearn also shuffles the data, since we don't want all elements of the same taxonomy next to each other.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
Let's look at our evaluation set.
class_counts = Counter(y_test)
plt.bar(class_counts.keys(), class_counts.values());
We can see that the distribution of classes in our evaluation set is close to that of the original data. This ensures the metrics we calculate are representative of the model's performance.
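If an exactly balanced evaluation set is desired, the split can additionally be stratified by the labels. This optional variant (not used in what follows) guarantees that each class appears in the test set in its overall proportion.
# Optional stratified split: each class is represented in the test set
# in exactly its overall proportion.
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.33, random_state=42, stratify=y)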
Let's train a Logistic Regression model to predict our taxonomy.
clf = LogisticRegression(random_state=0).fit(X_train, y_train)
And use the trained model to predict classes on the evaluation set.
y_pred = clf.predict(X_test)
Now we can check how our model performs on this problem by calculating the accuracy score between the real and assigned labels. We obtain around 85% accuracy on 11 taxonomic classes. This could be further improved by tuning the parameters, or by using a more complex model to better represent the relationships in the data.
accuracy_score(y_test, y_pred)
0.8495867768595041
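As one illustration of such tuning (a sketch, not an exhaustively tuned model), standardizing the features and allowing the optimizer more iterations may already help:
# Possible tuning sketch: scale the features and raise the iteration limit.
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
clf2 = make_pipeline(StandardScaler(),
                     LogisticRegression(max_iter=1000, random_state=0))
clf2 = clf2.fit(X_train, y_train)
accuracy_score(y_test, clf2.predict(X_test))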
We can also study the classification performance for the individual categories using a confusion matrix.
cm = confusion_matrix(y_test,y_pred)
cmplt = ConfusionMatrixDisplay(cm,display_labels=clf.classes_)
cmplt.plot();
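The per-class recall can also be read off the confusion matrix directly; for example (an optional snippet, not part of the original material):
# Per-class recall: diagonal of the confusion matrix divided by the row sums.
recall = cm.diagonal() / cm.sum(axis=1)
for lab, r in zip(clf.classes_, recall):
    print(f"{lab}: {r:.3f}")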