Machine Learning for Small Bodies in the Solar System - Supplemental Material

Chapter 6. Asteroid spectral classification

Additional material to Chapter 6: Asteroid spectro-photometric characterisation, in: Valerio Carruba, Evgeny Smirnov, Dagmara Anna Oszkiewicz (eds.), Machine Learning for Small Bodies in the Solar System, Elsevier, 2024.

Authors of chapter 6 and the following material:
Dagmara Oszkiewicz¹, Antti Penttilä², and Hanna Klimczak-Plucinska¹

¹Astronomical Observatory Institute, Faculty of Physics, Adam Mickiewicz University, Słoneczna 36, 60-286 Poznan, Poland
²Department of Physics, University of Helsinki, Finland
03-04-2024

We need the following Python packages.

In [1]:
# For data simulation with PCA
import pandas as pd
import numpy as np
import scipy.linalg
%matplotlib inline
from matplotlib import pyplot as plt
from prettytable import PrettyTable
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import warnings
warnings.filterwarnings("ignore")
In [2]:
# For classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from collections import Counter

We will use our own collection of asteroid spectral data and taxonomic labels, based on the SMASSII and MITHNEOS data, which we have merged and whose taxonomic labels we have simplified. First, let's create the wavelengths of the data. There is no 550 nm wavelength in the data, since the spectra are normalized there and the value would always be 1.

In [3]:
temp1 = np.arange(450,550,10)
temp2 = np.arange(560,2455,10)
wavelengths = np.concatenate([temp1,temp2])
len(wavelengths)
Out[3]:
200

Then, read the data in as a Pandas dataframe. The wavelengths are used as column labels, together with the label 'tax' for the taxonomy.

In [4]:
astdata = pd.read_csv('asteroid-spectral-data.csv',header=None,names=np.concatenate([['tax'],wavelengths]))
astdata
Out[4]:
tax 450 460 470 480 490 500 510 520 530 ... 2360 2370 2380 2390 2400 2410 2420 2430 2440 2450
0 C 0.936682 0.948220 0.958273 0.966964 0.974415 0.980751 0.986094 0.990568 0.994294 ... 1.002864 1.003190 1.003550 1.003946 1.004382 1.004860 1.005383 1.005955 1.006578 1.007254
1 B 0.972303 0.977751 0.982461 0.986487 0.989882 0.992699 0.994993 0.996816 0.998223 ... 0.836224 0.835711 0.835245 0.834830 0.834470 0.834168 0.833927 0.833752 0.833646 0.833612
2 S 0.882684 0.895862 0.910649 0.924991 0.937650 0.948964 0.959456 0.969644 0.979917 ... 1.256736 1.257612 1.258852 1.260557 1.262824 1.265675 1.268821 1.271896 1.274532 1.276362
3 V 0.901729 0.912059 0.924331 0.937435 0.950260 0.961788 0.971787 0.980430 0.987896 ... 1.118528 1.125348 1.132181 1.138816 1.145046 1.150661 1.155450 1.159206 1.161719 1.162779
4 S 0.857128 0.876266 0.892801 0.907897 0.922509 0.936765 0.950582 0.963880 0.976577 ... 1.287134 1.288773 1.290575 1.292559 1.294742 1.297205 1.300269 1.304318 1.309735 1.316905
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
581 L 0.964614 0.956561 0.951836 0.950159 0.951247 0.954816 0.960585 0.968270 0.977589 ... 1.707470 1.715092 1.723149 1.731653 1.740618 1.750054 1.759973 1.770388 1.781310 1.792751
582 S 0.778583 0.818356 0.852047 0.880469 0.904433 0.924750 0.942231 0.957689 0.971935 ... 1.384155 1.391469 1.398232 1.404404 1.409946 1.414820 1.418988 1.422410 1.425049 1.426865
583 X 1.041278 1.028260 1.018052 1.010335 1.004792 1.001103 0.998951 0.998016 0.998013 ... 1.226848 1.229681 1.232610 1.235639 1.238770 1.242007 1.245353 1.248810 1.252381 1.256070
584 Q 0.853625 0.864378 0.876155 0.889071 0.903245 0.918794 0.935665 0.953124 0.970268 ... 0.904400 0.903679 0.903112 0.902801 0.902851 0.903363 0.904442 0.906189 0.908709 0.912104
585 L 0.912581 0.924933 0.936311 0.946782 0.956412 0.965266 0.973412 0.980913 0.987822 ... 1.357655 1.360819 1.364037 1.367312 1.370643 1.374032 1.377481 1.380989 1.384559 1.388191

586 rows × 201 columns

Let's see what the spectra look like for the first five asteroids.

In [5]:
temp = astdata.iloc[0:5,1:].to_numpy()
plt.plot(wavelengths, temp.transpose());
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();

Save the dataframe in HDF5 format for faster access.

In [8]:
astdata.to_hdf('astdata.h5', key='astdata', mode='w')  # writing HDF5 requires the PyTables package

Simulate more spectra from existing samples using PCA

We will show how to create simulated samples from the spectral taxonomic groups using the conversion data -> PCA representation -> data.

First, we will load the asteroid spectral data as a Pandas dataframe.

In [3]:
astdata = pd.read_hdf('astdata.h5')
astdata
Out[3]:
tax 450 460 470 480 490 500 510 520 530 ... 2360 2370 2380 2390 2400 2410 2420 2430 2440 2450
0 C 0.936682 0.948220 0.958273 0.966964 0.974415 0.980751 0.986094 0.990568 0.994294 ... 1.002864 1.003190 1.003550 1.003946 1.004382 1.004860 1.005383 1.005955 1.006578 1.007254
1 B 0.972303 0.977751 0.982461 0.986487 0.989882 0.992699 0.994993 0.996816 0.998223 ... 0.836224 0.835711 0.835245 0.834830 0.834470 0.834168 0.833927 0.833752 0.833646 0.833612
2 S 0.882684 0.895862 0.910649 0.924991 0.937650 0.948964 0.959456 0.969644 0.979917 ... 1.256736 1.257612 1.258852 1.260557 1.262824 1.265675 1.268821 1.271896 1.274532 1.276362
3 V 0.901729 0.912059 0.924331 0.937435 0.950260 0.961788 0.971787 0.980430 0.987896 ... 1.118528 1.125348 1.132181 1.138816 1.145046 1.150661 1.155450 1.159206 1.161719 1.162779
4 S 0.857128 0.876266 0.892801 0.907897 0.922509 0.936765 0.950582 0.963880 0.976577 ... 1.287134 1.288773 1.290575 1.292559 1.294742 1.297205 1.300269 1.304318 1.309735 1.316905
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
581 L 0.964614 0.956561 0.951836 0.950159 0.951247 0.954816 0.960585 0.968270 0.977589 ... 1.707470 1.715092 1.723149 1.731653 1.740618 1.750054 1.759973 1.770388 1.781310 1.792751
582 S 0.778583 0.818356 0.852047 0.880469 0.904433 0.924750 0.942231 0.957689 0.971935 ... 1.384155 1.391469 1.398232 1.404404 1.409946 1.414820 1.418988 1.422410 1.425049 1.426865
583 X 1.041278 1.028260 1.018052 1.010335 1.004792 1.001103 0.998951 0.998016 0.998013 ... 1.226848 1.229681 1.232610 1.235639 1.238770 1.242007 1.245353 1.248810 1.252381 1.256070
584 Q 0.853625 0.864378 0.876155 0.889071 0.903245 0.918794 0.935665 0.953124 0.970268 ... 0.904400 0.903679 0.903112 0.902801 0.902851 0.903363 0.904442 0.906189 0.908709 0.912104
585 L 0.912581 0.924933 0.936311 0.946782 0.956412 0.965266 0.973412 0.980913 0.987822 ... 1.357655 1.360819 1.364037 1.367312 1.370643 1.374032 1.377481 1.380989 1.384559 1.388191

586 rows × 201 columns

Get the wavelengths from the column labels.

In [4]:
wavelengths = astdata.columns.to_numpy()[1:].astype('float')
len(wavelengths)
Out[4]:
200

Group data by taxonomic label.

In [5]:
astgrdata = astdata.groupby('tax')

How many groups, what are the groups, and how many asteroids per group?

In [6]:
astgrdata.ngroups
astgrdata.size()
Out[6]:
11
Out[6]:
tax
A      7
B     12
C     61
D     22
K     15
L     33
Q     43
S    310
T      4
V     28
X     51
dtype: int64
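
Note the strong class imbalance: 310 S-type spectra but only 4 T-types. This imbalance is one motivation for simulating an equal number of samples per class below.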

Data -> PCA

We will do the PCA decomposition 'by hand' for each taxonomic group, since we will need the inverse transform in the end.
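
In matrix form, writing $\mathbf{V}$ for the matrix whose columns are the eigenvectors of a group's covariance matrix and $\boldsymbol{\mu}$ for the group's mean spectrum, the transform pair used below is

$$\mathbf{p} = (\mathbf{x} - \boldsymbol{\mu})\,\mathbf{V}, \qquad \mathbf{x} = \mathbf{p}\,\mathbf{V}^{\mathsf{T}} + \boldsymbol{\mu},$$

where the inverse transform holds because the eigenvector matrix of a symmetric covariance matrix is orthogonal.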

In [7]:
meanvecs = []
covmats = []
for tax, gr in astgrdata:
    print("Group "+tax)
    # Numerical data per group to numpy array
    tempd = gr.iloc[:,1:].to_numpy()
    print("Data shape: ", np.shape(tempd))
    # Mean vector
    mv = np.mean(tempd,axis=0)
    meanvecs.append(mv)
    # Center data
    tempdc = tempd-mv
    # Covariance matrix
    cm = np.cov(tempdc,rowvar=False)
    covmats.append(cm)
Group A
Data shape:  (7, 200)
Group B
Data shape:  (12, 200)
Group C
Data shape:  (61, 200)
Group D
Data shape:  (22, 200)
Group K
Data shape:  (15, 200)
Group L
Data shape:  (33, 200)
Group Q
Data shape:  (43, 200)
Group S
Data shape:  (310, 200)
Group T
Data shape:  (4, 200)
Group V
Data shape:  (28, 200)
Group X
Data shape:  (51, 200)

Eigenvalue decompositions of covariance matrices.

In [8]:
evalvecs = []
evecmats = []
for cm in covmats:
    evals, evecs = scipy.linalg.eig(cm)
    # eig may return values with tiny imaginary parts; keep only the real part
    evalvecs.append(np.real(evals))
    evecmats.append(np.real(evecs))
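
As a quick sanity check (our addition, not part of the original notebook), we can verify the round trip data -> PCA -> data on a single spectrum: since the eigenvector matrix is orthogonal, projecting a centred spectrum onto the eigenvectors and back should reconstruct it up to floating-point error. Here 7 is the index of group 'S' in the alphabetically ordered groups.

x = astgrdata.get_group('S').iloc[0, 1:].to_numpy().astype(float)
p = (x - meanvecs[7]) @ evecmats[7]      # data -> PCA scores
xrec = p @ evecmats[7].T + meanvecs[7]   # PCA scores -> data
print(np.allclose(x, xrec))              # expect True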

Table showing the cumulative variance of the first 7 PCA components for each taxonomy group.

In [9]:
tab = PrettyTable(["Taxonomy", "PCA1", "PCA2", "PCA3", "PCA4", "PCA5", "PCA6", "PCA7"])
taxlabs = astgrdata.groups.keys()
for tax,ev in zip(taxlabs,evalvecs):
    t = np.concatenate([[tax],np.cumsum(ev[:7])/np.sum(ev)]).tolist()
    tab.add_row(t)
tab
Out[9]:
Taxonomy   PCA1     PCA2     PCA3     PCA4     PCA5     PCA6     PCA7
A          0.9078   0.9740   0.9942   0.9980   0.9998   1.0000   1.0000
B          0.8342   0.9388   0.9698   0.9878   0.9936   0.9968   0.9987
C          0.8669   0.9632   0.9820   0.9892   0.9939   0.9966   0.9982
D          0.9674   0.9869   0.9954   0.9976   0.9985   0.9992   0.9996
K          0.8899   0.9665   0.9871   0.9938   0.9977   0.9986   0.9991
L          0.8918   0.9600   0.9786   0.9904   0.9951   0.9975   0.9983
Q          0.8512   0.9237   0.9474   0.9640   0.9766   0.9846   0.9889
S          0.9031   0.9626   0.9737   0.9827   0.9896   0.9932   0.9950
T          0.9379   0.9778   1.0000   1.0000   1.0000   1.0000   1.0000
V          0.8910   0.9452   0.9682   0.9847   0.9934   0.9962   0.9972
X          0.9190   0.9829   0.9900   0.9944   0.9964   0.9979   0.9985
(cumulative explained variance, rounded to four decimals)
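
If scikit-learn is available, its PCA class gives an independent cross-check of the by-hand computation; a minimal sketch for the 'S' group (our addition, assuming the astgrdata grouping above) follows. The printed cumulative explained-variance ratios should match the S row of the table up to numerical precision.

from sklearn.decomposition import PCA

d = astgrdata.get_group('S').iloc[:, 1:].to_numpy()
pca = PCA(n_components=7).fit(d)
# Cumulative fraction of the total variance explained by the first 7 components
print(np.cumsum(pca.explained_variance_ratio_))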

Random sample based on PCA representation of the real spectra

An example of how to simulate a spectral sample based on the PCA representation of the real samples.

Let's take one S-type spectrum as an example.

In [10]:
one = astgrdata.get_group('S').iloc[0,1:].to_numpy()
plt.plot(wavelengths,one);
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();

Convert to the PCA representation ('S' is the label at index 7 in the grouped data).

In [11]:
onepca = (one-meanvecs[7]) @ evecmats[7]

Simulate random noise for the components according to their eigenvalues. A scaling factor of 0.6 is applied to the standard deviations (the square roots of the eigenvalues, which are the variances).

In [12]:
rn = np.random.normal(scale=0.6*np.sqrt(np.clip(evalvecs[7],0,None)))

Add noise and transform back from PCA space.

In [13]:
onepca1 = onepca + rn
one1 = onepca1 @ evecmats[7].T + meanvecs[7]

Show the original and the random sample based on it.

In [14]:
plt.plot(wavelengths,one,wavelengths,one1);
plt.xlabel('wavelength');
plt.ylabel('reflectance');
plt.axhline(1, linestyle='--', color='k')
plt.show();

PCA -> data

Let's apply the above example to all taxonomic groups and simulate 2000 samples for each taxonomy.

In [15]:
k = 2000
noisescale = 0.6
astsimdata = []
taxlab = list(astgrdata.groups.keys())
for i in range(astgrdata.ngroups):
    print("Group "+taxlab[i])
    d1 = astgrdata.get_group(taxlab[i]).iloc[:,1:].to_numpy()
    print(str(len(d1))+" asteroids in group")
    # Base each simulated sample on a randomly chosen real spectrum
    ranind = np.random.randint(0,high=len(d1),size=k)
    sample = np.zeros((k,len(d1[0])))
    for j in range(k):
        basespectra = d1[ranind[j]]
        # PCA representation
        pca = (basespectra-meanvecs[i]) @ evecmats[i]
        # Random change to PCA components
        rannoise = np.random.normal(scale=noisescale*np.sqrt(np.clip(evalvecs[i],0,None)))
        # Back to spectra
        simspectra = (pca + rannoise) @ evecmats[i].T + meanvecs[i]
        sample[j] = simspectra
    astsimdata.append(sample)
Group A
7 asteroids in group
Group B
12 asteroids in group
Group C
61 asteroids in group
Group D
22 asteroids in group
Group K
15 asteroids in group
Group L
33 asteroids in group
Group Q
43 asteroids in group
Group S
310 asteroids in group
Group T
4 asteroids in group
Group V
28 asteroids in group
Group X
51 asteroids in group

Plot the samples from each group.

In [16]:
for i in range(astgrdata.ngroups):
    _ = plt.plot(wavelengths,astsimdata[i].T)
    _ = plt.xlabel('wavelength')
    _ = plt.ylabel('reflectance')
    _ = plt.axhline(1, linestyle='--', color='k')
    _ = plt.title(taxlab[i])
    plt.show()

Collect the simulated samples into a single matrix and attach the taxonomy labels.

In [17]:
# Concatenate into one matrix.
onematrix = np.concatenate(astsimdata)
# Build the vector of labels
labs = []
for tax in taxlab:
    labs.append(np.full((k,1),tax))
onelabs = np.concatenate(labs)
# Join labels and data
outmatrix = np.c_[onelabs, onematrix]
np.shape(outmatrix)
Out[17]:
(22000, 201)

Save the simulated samples as an HDF5 file.

In [18]:
astsimdf = pd.DataFrame(outmatrix, columns=np.concatenate([['tax'], wavelengths]))
astsimdf.to_hdf('astdata-simulated.h5', key='astsimdata', mode='w')
astsimdf
Out[18]:
tax 450.0 460.0 470.0 480.0 490.0 500.0 510.0 520.0 530.0 ... 2360.0 2370.0 2380.0 2390.0 2400.0 2410.0 2420.0 2430.0 2440.0 2450.0
0 A 0.705580511979575 0.7326578377282789 0.7612766064650269 0.790763773628808 0.8206035917883197 0.8507091598684661 0.8809020071529161 0.9110709048791528 0.9411099554544028 ... 2.605407249936357 2.6063895690546035 2.6070578522133414 2.6057334305850905 2.605000334669502 2.6042814767891067 2.6038356737304618 2.603928861322881 2.6048270326597267 2.6067961589745665
1 A 0.7543212749301422 0.7710331551024008 0.794609162338694 0.8209888707426471 0.8468705491698457 0.8720104059298207 0.8969056664336702 0.9220026220704322 0.9477351002160105 ... 2.200904117908356 2.205464986153305 2.2115637743242025 2.2118412622746746 2.215309051608229 2.2187902569306326 2.2221600828141983 2.225281601120992 2.2280178677694655 2.230231876101838
2 A 0.7637397260963806 0.7780492534053065 0.8002521804826961 0.8258716515810832 0.8511833980523711 0.8752433216064436 0.8983694405569923 0.9217725174326638 0.94665728036047 ... 2.0732455857656777 2.0751229179820574 2.0775157541381484 2.078554940400557 2.0809079774767447 2.083817880163292 2.087189007217301 2.0908761768045387 2.094734212847226 2.098617917080656
3 A 0.7738975011698456 0.7885199044747664 0.8104609751854193 0.8351439444306219 0.8588025014864068 0.8807442617571755 0.9017277150101468 0.9233332089256702 0.9471102273205482 ... 1.937158735691603 1.9393691914491984 1.9420789864601704 1.9457076027600417 1.950134131738519 1.9555722654659562 1.9618640516347374 1.9687842272095297 1.9761075265309989 1.9836086802985917
4 A 0.7563375306899326 0.7770924206584446 0.8001907701892301 0.823833745176814 0.8467578142153621 0.8695916932080284 0.8930824527035657 0.917847724924682 0.9444491125936656 ... 2.0743223212649915 2.0763653677692653 2.0785748724238395 2.082123774871022 2.085646598284524 2.08960955789617 2.093777868694627 2.0978615656955437 2.1015707158349075 2.1046153960238327
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
21995 X 0.9352696125079532 0.9428596301600836 0.9540222528029001 0.9664716766270952 0.9780333038463478 0.9867203231201302 0.9924311608152832 0.9957340067082262 0.9974785762425892 ... 1.721042477022407 1.722614402092429 1.7240860694899345 1.725510096014843 1.7269391593810006 1.7284056969204014 1.7299219023392576 1.7314991074464627 1.7331486375878389 1.7348774878656745
21996 X 1.01839230087239 1.0082193893104505 1.0035992201952535 1.0026919824190126 1.0031222623804033 1.0046887231165573 1.0053860240667778 1.004993555143038 1.003624977355195 ... 1.4566814198609133 1.4579935704996116 1.4591465475145518 1.460140470839626 1.4609745620089734 1.461639346814181 1.462116680079996 1.462388023544818 1.4624348548964001 1.4622409902706912
21997 X 0.9485436316078056 0.9529741538793863 0.9576646382182891 0.9626506818592507 0.967793809856022 0.9736716213173331 0.9795531171084139 0.9851950979312695 0.9905262438886581 ... 1.5540567458541845 1.5558885833471405 1.5573905910749701 1.5585295371378336 1.5592726465283473 1.5595962137855905 1.5594856199893876 1.5589266396795327 1.557905036446503 1.5564050708327737
21998 X 0.8901368475123629 0.9167001040797195 0.9385639144201832 0.9555572027945662 0.9675228722381615 0.9745375588184017 0.9787498403896902 0.9825660630780803 0.9869132002082837 ... 1.4148131758852065 1.4161437075661685 1.4174508624879245 1.418696303056643 1.4198412027358978 1.4208416611712187 1.4216486847396312 1.4222130626837708 1.422485588352071 1.4224156954934206
21999 X 0.9463348150783925 0.9441194489542596 0.9414787352728049 0.9397497567217487 0.9405898310082094 0.9439223495558785 0.9510582164011503 0.9615533586896492 0.9740879621503112 ... 1.3950699387230134 1.3950505268487752 1.394778368545595 1.3942186633809017 1.39333757049001 1.3921276890617873 1.3906080291056508 1.3887987584209467 1.3867200351033229 1.3843922257656194

22000 rows × 201 columns

Automatic classification with logistic regression

Next, we will use our simulated data to train a machine learning model to assign taxonomic classes.

Read in the simulated asteroid spectral data.

In [19]:
astsimdf = pd.read_hdf('astdata-simulated.h5')
astsimdf
Out[19]:
tax 450.0 460.0 470.0 480.0 490.0 500.0 510.0 520.0 530.0 ... 2360.0 2370.0 2380.0 2390.0 2400.0 2410.0 2420.0 2430.0 2440.0 2450.0
0 A 0.705580511979575 0.7326578377282789 0.7612766064650269 0.790763773628808 0.8206035917883197 0.8507091598684661 0.8809020071529161 0.9110709048791528 0.9411099554544028 ... 2.605407249936357 2.6063895690546035 2.6070578522133414 2.6057334305850905 2.605000334669502 2.6042814767891067 2.6038356737304618 2.603928861322881 2.6048270326597267 2.6067961589745665
1 A 0.7543212749301422 0.7710331551024008 0.794609162338694 0.8209888707426471 0.8468705491698457 0.8720104059298207 0.8969056664336702 0.9220026220704322 0.9477351002160105 ... 2.200904117908356 2.205464986153305 2.2115637743242025 2.2118412622746746 2.215309051608229 2.2187902569306326 2.2221600828141983 2.225281601120992 2.2280178677694655 2.230231876101838
2 A 0.7637397260963806 0.7780492534053065 0.8002521804826961 0.8258716515810832 0.8511833980523711 0.8752433216064436 0.8983694405569923 0.9217725174326638 0.94665728036047 ... 2.0732455857656777 2.0751229179820574 2.0775157541381484 2.078554940400557 2.0809079774767447 2.083817880163292 2.087189007217301 2.0908761768045387 2.094734212847226 2.098617917080656
3 A 0.7738975011698456 0.7885199044747664 0.8104609751854193 0.8351439444306219 0.8588025014864068 0.8807442617571755 0.9017277150101468 0.9233332089256702 0.9471102273205482 ... 1.937158735691603 1.9393691914491984 1.9420789864601704 1.9457076027600417 1.950134131738519 1.9555722654659562 1.9618640516347374 1.9687842272095297 1.9761075265309989 1.9836086802985917
4 A 0.7563375306899326 0.7770924206584446 0.8001907701892301 0.823833745176814 0.8467578142153621 0.8695916932080284 0.8930824527035657 0.917847724924682 0.9444491125936656 ... 2.0743223212649915 2.0763653677692653 2.0785748724238395 2.082123774871022 2.085646598284524 2.08960955789617 2.093777868694627 2.0978615656955437 2.1015707158349075 2.1046153960238327
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
21995 X 0.9352696125079532 0.9428596301600836 0.9540222528029001 0.9664716766270952 0.9780333038463478 0.9867203231201302 0.9924311608152832 0.9957340067082262 0.9974785762425892 ... 1.721042477022407 1.722614402092429 1.7240860694899345 1.725510096014843 1.7269391593810006 1.7284056969204014 1.7299219023392576 1.7314991074464627 1.7331486375878389 1.7348774878656745
21996 X 1.01839230087239 1.0082193893104505 1.0035992201952535 1.0026919824190126 1.0031222623804033 1.0046887231165573 1.0053860240667778 1.004993555143038 1.003624977355195 ... 1.4566814198609133 1.4579935704996116 1.4591465475145518 1.460140470839626 1.4609745620089734 1.461639346814181 1.462116680079996 1.462388023544818 1.4624348548964001 1.4622409902706912
21997 X 0.9485436316078056 0.9529741538793863 0.9576646382182891 0.9626506818592507 0.967793809856022 0.9736716213173331 0.9795531171084139 0.9851950979312695 0.9905262438886581 ... 1.5540567458541845 1.5558885833471405 1.5573905910749701 1.5585295371378336 1.5592726465283473 1.5595962137855905 1.5594856199893876 1.5589266396795327 1.557905036446503 1.5564050708327737
21998 X 0.8901368475123629 0.9167001040797195 0.9385639144201832 0.9555572027945662 0.9675228722381615 0.9745375588184017 0.9787498403896902 0.9825660630780803 0.9869132002082837 ... 1.4148131758852065 1.4161437075661685 1.4174508624879245 1.418696303056643 1.4198412027358978 1.4208416611712187 1.4216486847396312 1.4222130626837708 1.422485588352071 1.4224156954934206
21999 X 0.9463348150783925 0.9441194489542596 0.9414787352728049 0.9397497567217487 0.9405898310082094 0.9439223495558785 0.9510582164011503 0.9615533586896492 0.9740879621503112 ... 1.3950699387230134 1.3950505268487752 1.394778368545595 1.3942186633809017 1.39333757049001 1.3921276890617873 1.3906080291056508 1.3887987584209467 1.3867200351033229 1.3843922257656194

22000 rows × 201 columns

Let's look at the class distribution in our simulated data.

In [21]:
n = len(astsimdf['tax'].unique().tolist())
p = sns.color_palette("husl", n)
sns.countplot(astsimdf, x="tax", hue='tax', palette=p);

As expected, the number of elements for each taxonomy is 2000. This balance is important for our classification accuracy, ensuring that all classes are equally represented during training and evaluation.

Let's select the input and output of our model.

In [22]:
X, y = astsimdf.drop(['tax'], axis=1), astsimdf['tax']

Next, we split our data into training and test sets. We will use 67% of our data for training and the rest for evaluation. The split method from sklearn also shuffles the data, as we don't want all elements of the same taxonomy next to each other.

In [23]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

Let's look at our evaluation set.

In [24]:
class_counts = Counter(y_test)

plt.bar(class_counts.keys(), class_counts.values());

We can see that the distribution of classes in our evaluation set is close to that of the original data. This ensures that the metrics we calculate are representative of the model's performance.
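
As a side note (our suggestion, not part of the chapter's procedure), passing stratify=y to train_test_split would make the class proportions in both splits exact rather than approximate:

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.33,
                                          random_state=42, stratify=y)
print(Counter(y_te))  # each of the 11 classes appears 660 times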

Let's train a Logistic Regression model to predict our taxonomy.

In [25]:
clf = LogisticRegression(random_state=0).fit(X_train, y_train)

And use the trained model to predict classes on the evaluation set.

In [26]:
y_pred = clf.predict(X_test)

Now we can check how our model performs on this problem. We calculate the accuracy score between the true and predicted labels, obtaining around 85% accuracy over 11 taxonomic classes. This could be further improved by tuning the hyperparameters or by using a more complex model that better represents the relationships in the data (see the sketch after the accuracy output below).

In [27]:
accuracy_score(y_test, y_pred)
Out[27]:
0.8495867768595041
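
As one possible direction for the improvement mentioned above, here is a sketch (our suggestion, not the chapter's method) that standardises the features and raises the solver's iteration limit before fitting:

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

clf2 = make_pipeline(StandardScaler(),
                     LogisticRegression(max_iter=1000, random_state=0))
clf2.fit(X_train, y_train)
print(accuracy_score(y_test, clf2.predict(X_test)))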

We can also study the classification performance for the individual categories using a confusion matrix.

In [28]:
cm = confusion_matrix(y_test,y_pred)
In [29]:
cmplt = ConfusionMatrixDisplay(cm,display_labels=clf.classes_)
cmplt.plot();
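
Per-class precision, recall, and F1 scores complement the confusion matrix; a short addition using scikit-learn's classification_report:

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred))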