This project focuses on credit card fraud. Fraud detection is one of the toughest machine learning challenges due to imbalanced data, irregular and hard-to-identify patterns, missing features, and live transactions. Building a model that consumes live streaming data, learns from incoming transactions, updates its picture of normal transaction patterns, and flags anomalies is valuable in many domains.
The Federal Trade Commission reported receiving 2.8 million fraud reports from consumers in 2021; consumer losses reached $5.8 billion, a 70% increase over 2020. Fraudsters are using increasingly advanced techniques, including machine learning, to target new customers, exploit online transactions, and steal identities. Many models have been proposed to improve fraud detection, including KNN, logistic regression, and SVM, and for preprocessing, under-sampling, over-sampling, and feature selection (PCA, logistic regression, SVM) are widely used; reported recall for credit card fraud detection can reach 0.94. Yet fraudulent activity keeps growing year over year, and fraudsters use machine learning themselves to evade defensive algorithms. Simply labeling or defining outliers is no longer enough to identify attack patterns. What is needed is a platform that combines data streaming, data preprocessing (feature selection, auto-labeling, grouping), model selection, model training, relearning from live transaction data, and prediction. This live data-model interaction facilitates model selection and updating, which in turn speeds up anomaly detection.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the Kaggle credit card fraud dataset
filepath = "Data/creditcard.csv"
df = pd.read_csv(filepath)
Check the basic info of the dataset.
df.head(5)
 | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | -1.359807 | -0.072781 | 2.536347 | 1.378155 | -0.338321 | 0.462388 | 0.239599 | 0.098698 | 0.363787 | ... | -0.018307 | 0.277838 | -0.110474 | 0.066928 | 0.128539 | -0.189115 | 0.133558 | -0.021053 | 149.62 | 0 |
1 | 0.0 | 1.191857 | 0.266151 | 0.166480 | 0.448154 | 0.060018 | -0.082361 | -0.078803 | 0.085102 | -0.255425 | ... | -0.225775 | -0.638672 | 0.101288 | -0.339846 | 0.167170 | 0.125895 | -0.008983 | 0.014724 | 2.69 | 0 |
2 | 1.0 | -1.358354 | -1.340163 | 1.773209 | 0.379780 | -0.503198 | 1.800499 | 0.791461 | 0.247676 | -1.514654 | ... | 0.247998 | 0.771679 | 0.909412 | -0.689281 | -0.327642 | -0.139097 | -0.055353 | -0.059752 | 378.66 | 0 |
3 | 1.0 | -0.966272 | -0.185226 | 1.792993 | -0.863291 | -0.010309 | 1.247203 | 0.237609 | 0.377436 | -1.387024 | ... | -0.108300 | 0.005274 | -0.190321 | -1.175575 | 0.647376 | -0.221929 | 0.062723 | 0.061458 | 123.50 | 0 |
4 | 2.0 | -1.158233 | 0.877737 | 1.548718 | 0.403034 | -0.407193 | 0.095921 | 0.592941 | -0.270533 | 0.817739 | ... | -0.009431 | 0.798278 | -0.137458 | 0.141267 | -0.206010 | 0.502292 | 0.219422 | 0.215153 | 69.99 | 0 |
5 rows × 31 columns
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 284807 entries, 0 to 284806
Data columns (total 31 columns):
 #   Column  Non-Null Count   Dtype
---  ------  --------------   -----
 0   Time    284807 non-null  float64
 1   V1      284807 non-null  float64
 2   V2      284807 non-null  float64
 3   V3      284807 non-null  float64
 4   V4      284807 non-null  float64
 5   V5      284807 non-null  float64
 6   V6      284807 non-null  float64
 7   V7      284807 non-null  float64
 8   V8      284807 non-null  float64
 9   V9      284807 non-null  float64
 10  V10     284807 non-null  float64
 11  V11     284807 non-null  float64
 12  V12     284807 non-null  float64
 13  V13     284807 non-null  float64
 14  V14     284807 non-null  float64
 15  V15     284807 non-null  float64
 16  V16     284807 non-null  float64
 17  V17     284807 non-null  float64
 18  V18     284807 non-null  float64
 19  V19     284807 non-null  float64
 20  V20     284807 non-null  float64
 21  V21     284807 non-null  float64
 22  V22     284807 non-null  float64
 23  V23     284807 non-null  float64
 24  V24     284807 non-null  float64
 25  V25     284807 non-null  float64
 26  V26     284807 non-null  float64
 27  V27     284807 non-null  float64
 28  V28     284807 non-null  float64
 29  Amount  284807 non-null  float64
 30  Class   284807 non-null  int64
dtypes: float64(30), int64(1)
memory usage: 67.4 MB
df.dtypes
Time      float64
V1        float64
V2        float64
V3        float64
V4        float64
V5        float64
V6        float64
V7        float64
V8        float64
V9        float64
V10       float64
V11       float64
V12       float64
V13       float64
V14       float64
V15       float64
V16       float64
V17       float64
V18       float64
V19       float64
V20       float64
V21       float64
V22       float64
V23       float64
V24       float64
V25       float64
V26       float64
V27       float64
V28       float64
Amount    float64
Class       int64
dtype: object
df.describe()
 | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 284807.000000 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | ... | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 2.848070e+05 | 284807.000000 | 284807.000000 |
mean | 94813.859575 | 3.918649e-15 | 5.682686e-16 | -8.761736e-15 | 2.811118e-15 | -1.552103e-15 | 2.040130e-15 | -1.698953e-15 | -1.893285e-16 | -3.147640e-15 | ... | 1.473120e-16 | 8.042109e-16 | 5.282512e-16 | 4.456271e-15 | 1.426896e-15 | 1.701640e-15 | -3.662252e-16 | -1.217809e-16 | 88.349619 | 0.001727 |
std | 47488.145955 | 1.958696e+00 | 1.651309e+00 | 1.516255e+00 | 1.415869e+00 | 1.380247e+00 | 1.332271e+00 | 1.237094e+00 | 1.194353e+00 | 1.098632e+00 | ... | 7.345240e-01 | 7.257016e-01 | 6.244603e-01 | 6.056471e-01 | 5.212781e-01 | 4.822270e-01 | 4.036325e-01 | 3.300833e-01 | 250.120109 | 0.041527 |
min | 0.000000 | -5.640751e+01 | -7.271573e+01 | -4.832559e+01 | -5.683171e+00 | -1.137433e+02 | -2.616051e+01 | -4.355724e+01 | -7.321672e+01 | -1.343407e+01 | ... | -3.483038e+01 | -1.093314e+01 | -4.480774e+01 | -2.836627e+00 | -1.029540e+01 | -2.604551e+00 | -2.256568e+01 | -1.543008e+01 | 0.000000 | 0.000000 |
25% | 54201.500000 | -9.203734e-01 | -5.985499e-01 | -8.903648e-01 | -8.486401e-01 | -6.915971e-01 | -7.682956e-01 | -5.540759e-01 | -2.086297e-01 | -6.430976e-01 | ... | -2.283949e-01 | -5.423504e-01 | -1.618463e-01 | -3.545861e-01 | -3.171451e-01 | -3.269839e-01 | -7.083953e-02 | -5.295979e-02 | 5.600000 | 0.000000 |
50% | 84692.000000 | 1.810880e-02 | 6.548556e-02 | 1.798463e-01 | -1.984653e-02 | -5.433583e-02 | -2.741871e-01 | 4.010308e-02 | 2.235804e-02 | -5.142873e-02 | ... | -2.945017e-02 | 6.781943e-03 | -1.119293e-02 | 4.097606e-02 | 1.659350e-02 | -5.213911e-02 | 1.342146e-03 | 1.124383e-02 | 22.000000 | 0.000000 |
75% | 139320.500000 | 1.315642e+00 | 8.037239e-01 | 1.027196e+00 | 7.433413e-01 | 6.119264e-01 | 3.985649e-01 | 5.704361e-01 | 3.273459e-01 | 5.971390e-01 | ... | 1.863772e-01 | 5.285536e-01 | 1.476421e-01 | 4.395266e-01 | 3.507156e-01 | 2.409522e-01 | 9.104512e-02 | 7.827995e-02 | 77.165000 | 0.000000 |
max | 172792.000000 | 2.454930e+00 | 2.205773e+01 | 9.382558e+00 | 1.687534e+01 | 3.480167e+01 | 7.330163e+01 | 1.205895e+02 | 2.000721e+01 | 1.559499e+01 | ... | 2.720284e+01 | 1.050309e+01 | 2.252841e+01 | 4.584549e+00 | 7.519589e+00 | 3.517346e+00 | 3.161220e+01 | 3.384781e+01 | 25691.160000 | 1.000000 |
8 rows × 31 columns
df.isna().sum()
Time      0
V1        0
V2        0
V3        0
V4        0
V5        0
V6        0
V7        0
V8        0
V9        0
V10       0
V11       0
V12       0
V13       0
V14       0
V15       0
V16       0
V17       0
V18       0
V19       0
V20       0
V21       0
V22       0
V23       0
V24       0
V25       0
V26       0
V27       0
V28       0
Amount    0
Class     0
dtype: int64
Good news! There are no null values in the data set.
# How many fraud transactions are there? (Class == 1)
fraud_sum = (df['Class'] == 1).sum()
all_sum = df.shape[0]
print(f"There are {all_sum} transactions, and {fraud_sum} fraud transactions")
There are 284807 transactions, and 492 fraud transactions
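Fraud is therefore only about 0.17% of all transactions. A quick check of the class balance (a small sketch added here for clarity, using the counts computed above):

# Quantify the imbalance: the fraud class is a tiny fraction of the data,
# which is why plain accuracy is a misleading metric for this problem.
print(df['Class'].value_counts())          # 284315 normal vs 492 fraud
print(f"Fraud rate: {fraud_sum / all_sum:.4%}")   # ~0.1727%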
fraud = df.loc[df['Class']==1]
normal = df.loc[df['Class']==0]
fraud.head(5)
 | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
541 | 406.0 | -2.312227 | 1.951992 | -1.609851 | 3.997906 | -0.522188 | -1.426545 | -2.537387 | 1.391657 | -2.770089 | ... | 0.517232 | -0.035049 | -0.465211 | 0.320198 | 0.044519 | 0.177840 | 0.261145 | -0.143276 | 0.00 | 1 |
623 | 472.0 | -3.043541 | -3.157307 | 1.088463 | 2.288644 | 1.359805 | -1.064823 | 0.325574 | -0.067794 | -0.270953 | ... | 0.661696 | 0.435477 | 1.375966 | -0.293803 | 0.279798 | -0.145362 | -0.252773 | 0.035764 | 529.00 | 1 |
4920 | 4462.0 | -2.303350 | 1.759247 | -0.359745 | 2.330243 | -0.821628 | -0.075788 | 0.562320 | -0.399147 | -0.238253 | ... | -0.294166 | -0.932391 | 0.172726 | -0.087330 | -0.156114 | -0.542628 | 0.039566 | -0.153029 | 239.93 | 1 |
6108 | 6986.0 | -4.397974 | 1.358367 | -2.592844 | 2.679787 | -1.128131 | -1.706536 | -3.496197 | -0.248778 | -0.247768 | ... | 0.573574 | 0.176968 | -0.436207 | -0.053502 | 0.252405 | -0.657488 | -0.827136 | 0.849573 | 59.00 | 1 |
6329 | 7519.0 | 1.234235 | 3.019740 | -4.304597 | 4.732795 | 3.624201 | -1.357746 | 1.713445 | -0.496358 | -1.282858 | ... | -0.379068 | -0.704181 | -0.656805 | -1.632653 | 1.488901 | 0.566797 | -0.010016 | 0.146793 | 1.00 | 1 |
5 rows × 31 columns
Let's visualize the two non-PCA features, Time and Amount, with density plots and box plots, and note any observations.
external_factors = ['Time', 'Amount']
print('\033[1mNumeric Features Distribution'.center(100))
figsize = (12, 4)
n = len(external_factors)
colors = ['g', 'b', 'r', 'y', 'k']

# Histograms with KDE overlays (histplot replaces the deprecated distplot)
plt.figure(figsize=figsize)
for i in range(n):
    plt.subplot(1, n, i + 1)
    sns.histplot(df[external_factors[i]], bins=100, kde=True,
                 color=colors[i])
plt.tight_layout();
Numeric Features Distribution
print('\033[1m Fraud Numeric Features Distribution'.center(100))
plt.figure(figsize=figsize)
for i in range(n):
    plt.subplot(1, n, i + 1)
    sns.histplot(fraud[external_factors[i]], bins=50, kde=True,
                 color=colors[i])
plt.tight_layout();
Fraud Numeric Features Distribution
print('\033[1m Normal Numeric Features Distribution'.center(100))
plt.figure(figsize=figsize)
for i in range(n):
    plt.subplot(1, n, i + 1)
    sns.histplot(normal[external_factors[i]], bins=1000, kde=True,
                 color=colors[i])
plt.tight_layout();
Normal Numeric Features Distribution
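Amount is heavily right-skewed in all three views: most transactions are small, with a long tail. A log-scale histogram makes that tail easier to inspect (a quick sketch, not part of the original run; np.log1p is used so zero-amount transactions are handled):

# Log-transform Amount before plotting; log1p(x) = log(1 + x) is defined at x = 0.
plt.figure(figsize=(6, 4))
sns.histplot(np.log1p(df['Amount']), bins=100, color='b')
plt.xlabel('log(1 + Amount)')
plt.tight_layout();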
# Plot Amount over Time as a scatter, colored by class
plt.figure(figsize=(10, 8))
sns.scatterplot(x=df['Time'], y=df['Amount'], hue=df['Class'])
<AxesSubplot:xlabel='Time', ylabel='Amount'>
df.boxplot(column=external_factors, grid=False, figsize=(6,4))
<AxesSubplot:>
fraud.boxplot(column=['Time','Amount'], grid=False, figsize=(6,4))
<AxesSubplot:>
normal.boxplot(column=['Amount'], grid=False, figsize=(6,4))
<AxesSubplot:>
fraud.boxplot(column=['Time'], grid=False, figsize=(6,4))
<AxesSubplot:>
# Separate features and target
X = df.drop(['Class'], axis=1)
y = df['Class']
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)
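Note that with only ~0.17% positives, an unstratified split can shift the fraud rate between the partitions. A stratified variant (an alternative sketch; the results below were produced with the unstratified split above) keeps the class ratio identical in both:

# stratify=y preserves the fraud rate in both splits; useful when the
# positive class is this rare.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=43, stratify=y)
print(f"train fraud rate: {y_tr.mean():.4%}, test fraud rate: {y_te.mean():.4%}")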
plt.figure(figsize=(12, 10))
cor = X_train.corr()
sns.heatmap(cor, vmin=-1, vmax=1);
# Greedy filter: drop feature j if it correlates strongly with an earlier feature i
keep_columns = np.full(cor.shape[0], True)
for i in range(cor.shape[0] - 1):
    for j in range(i + 1, cor.shape[0]):   # include the last column as a candidate
        if np.abs(cor.iloc[i, j]) >= 0.8:  # 0.8 is the correlation threshold
            keep_columns[j] = False
selected_columns = X_train.columns[keep_columns]
X_train_reduced = X_train[selected_columns]
print(selected_columns)
Index(['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount'], dtype='object')
selected_columns==X.columns
array([ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True])
No features were dropped, as expected: V1-V28 come from a PCA transform and are therefore uncorrelated by construction, and Time and Amount do not cross the 0.8 threshold either.
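For reference, the same filter can be written without the double loop (a sketch using the upper triangle of the correlation matrix; not part of the original run):

# Keep only the upper triangle (k=1 excludes the diagonal), then flag any
# column whose absolute correlation with an earlier column is >= 0.8.
upper = cor.abs().where(np.triu(np.ones(cor.shape, dtype=bool), k=1))
to_drop = [col for col in upper.columns if (upper[col] >= 0.8).any()]
print(to_drop)  # expected: [] for this dataset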
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_curve, RocCurveDisplay
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.svm import SVC

rf = RandomForestClassifier(random_state=42)
lr = LogisticRegression(max_iter=100000, class_weight="balanced", random_state=42)
xg = XGBClassifier(random_state=42)
svc = SVC(random_state=42)

# Models with a decision_function (used for the ROC curves below); the
# tree-based models expose predict_proba instead, so they are evaluated separately.
clf_list = [lr, svc]
# Function to print the accuracy and classification report
def print_acc_and_CR(y_test, model_predictions):
    # Calculate and display the model accuracy
    print('Accuracy: {:.2f}%'.format(accuracy_score(y_test, model_predictions) * 100))
    # Calculate and display the classification report for the model
    print('Classification report: \n', classification_report(y_test, model_predictions))

# Function to display the confusion matrix
def display_confusion_matrix(y_test, model_predictions):
    # Calculate the confusion matrix
    model_confusionMatrix = confusion_matrix(y_test, model_predictions)
    # Create the quadrant labels that will be displayed as text on the plot
    strings2 = np.asarray([['True Negatives \n', 'False Positives \n'],
                           ['False Negatives \n', 'True Positives \n']])
    labels2 = (np.asarray(["{0} {1:g}".format(string, value)
                           for string, value in zip(strings2.flatten(), model_confusionMatrix.flatten())])
               ).reshape(2, 2)
    # Use a heat map plot to display the results
    sns.heatmap(model_confusionMatrix, annot=labels2, fmt='', vmin=0, annot_kws={"fontsize": 17})
    plt.xlabel('Predicted value')
    plt.ylabel('Actual value')
    plt.show()
for c in clf_list:
    clf = make_pipeline(StandardScaler(), c)
    clf.fit(X_train_reduced, y_train)
    y_pred = clf.predict(X_test)
    print("***************Model************", c)
    print_acc_and_CR(y_test, y_pred)
    # Score through the fitted pipeline so X_test is scaled consistently
    y_score = clf.decision_function(X_test)
    fpr, tpr, thresholds = roc_curve(y_test, y_score, pos_label=clf.classes_[1])
    roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
    plt.show()
    display_confusion_matrix(y_test, y_pred)
***************Model************ LogisticRegression(class_weight='balanced', max_iter=100000, random_state=42)
Accuracy: 97.77%
Classification report:
               precision    recall  f1-score   support

           0       1.00      0.98      0.99     56856
           1       0.07      0.91      0.13       106

    accuracy                           0.98     56962
   macro avg       0.54      0.94      0.56     56962
weighted avg       1.00      0.98      0.99     56962
***************Model************ SVC(random_state=42)
Accuracy: 99.93%
Classification report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00     56856
           1       0.94      0.69      0.79       106

    accuracy                           1.00     56962
   macro avg       0.97      0.84      0.90     56962
weighted avg       1.00      1.00      1.00     56962
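With a 0.17% positive rate, ROC curves can look deceptively good; a precision-recall curve is often the more informative view. A sketch (assuming the fitted pipeline clf from the loop above is still in scope; PrecisionRecallDisplay.from_predictions requires scikit-learn >= 1.0):

from sklearn.metrics import PrecisionRecallDisplay, average_precision_score

# Average precision summarizes the PR curve; it is far more sensitive than
# ROC AUC to performance on the rare fraud class.
y_score = clf.decision_function(X_test)
print("Average precision:", average_precision_score(y_test, y_score))
PrecisionRecallDisplay.from_predictions(y_test, y_score)
plt.show()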
# Tree-based models have no decision_function, so the ROC step is skipped here
tree_list = [rf, xg]
for c in tree_list:
    clf = make_pipeline(StandardScaler(), c)
    clf.fit(X_train_reduced, y_train)
    y_pred = clf.predict(X_test)
    print("***************Model************", c)
    print_acc_and_CR(y_test, y_pred)
    display_confusion_matrix(y_test, y_pred)
***************Model************ RandomForestClassifier(random_state=42)
Accuracy: 99.94%
Classification report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00     56856
           1       0.93      0.75      0.83       106

    accuracy                           1.00     56962
   macro avg       0.96      0.88      0.92     56962
weighted avg       1.00      1.00      1.00     56962
***************Model************ XGBClassifier(base_score=0.5, booster='gbtree', callbacks=None,
              colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
              early_stopping_rounds=None, enable_categorical=False,
              eval_metric=None, gamma=0, gpu_id=-1, grow_policy='depthwise',
              importance_type=None, interaction_constraints='',
              learning_rate=0.300000012, max_bin=256, max_cat_to_onehot=4,
              max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
              missing=nan, monotone_constraints='()', n_estimators=100,
              n_jobs=0, num_parallel_tree=1, predictor='auto', random_state=42,
              reg_alpha=0, reg_lambda=1, ...)
Accuracy: 99.95%
Classification report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00     56856
           1       0.94      0.77      0.85       106

    accuracy                           1.00     56962
   macro avg       0.97      0.89      0.92     56962
weighted avg       1.00      1.00      1.00     56962
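For completeness, the tree models can still be put on a ROC curve via predict_proba (a sketch; it refits the same pipelines, so the results match the run above):

# Use the positive-class probability as the ROC score in place of decision_function.
for c in tree_list:
    clf = make_pipeline(StandardScaler(), c)
    clf.fit(X_train_reduced, y_train)
    y_proba = clf.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(y_test, y_proba, pos_label=1)
    RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
    plt.show()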
# Set a number of features to consider when looking at rank/importance
n_features = 20
# Use Random Forest to get feature ranks/importances for each feature
importances = rf.feature_importances_
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis=0)
indices = np.argsort(importances)[::-1]

# Print the feature ranking
print("Feature ranking:")
# Look at importance for the first 20 features
for f in range(n_features):
    print("%d. %s (feature %d) (%f)" %
          (f + 1, X_train_reduced.columns[indices[f]], indices[f], importances[indices[f]]))

# Plot the impurity-based feature importances of the forest
plt.figure()
plt.title("Feature importances")
plt.bar(range(n_features), importances[indices[:n_features]],
        color="r", yerr=std[indices[:n_features]], align="center")
plt.xticks(range(n_features), indices[:n_features])
plt.xlim([-1, n_features]);
Feature ranking:
1. V17 (feature 17) (0.154722)
2. V14 (feature 14) (0.137542)
3. V12 (feature 12) (0.135451)
4. V16 (feature 16) (0.090930)
5. V10 (feature 10) (0.078660)
6. V11 (feature 11) (0.042432)
7. V18 (feature 18) (0.031745)
8. V9 (feature 9) (0.030973)
9. V4 (feature 4) (0.026802)
10. V7 (feature 7) (0.024217)
11. V26 (feature 26) (0.022822)
12. V3 (feature 3) (0.019078)
13. V21 (feature 21) (0.015852)
14. V20 (feature 20) (0.014621)
15. Time (feature 0) (0.013952)
16. V27 (feature 27) (0.013382)
17. V1 (feature 1) (0.013057)
18. V6 (feature 6) (0.013050)
19. V8 (feature 8) (0.013014)
20. Amount (feature 29) (0.012751)
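Impurity-based importances can be biased toward high-variance features; permutation importance on the held-out set is a useful cross-check. A sketch (assuming the fitted pipeline clf from the tree-model loop is still in scope; n_repeats is kept small since the test set is large):

from sklearn.inspection import permutation_importance

# Shuffle each feature in turn and measure the drop in held-out score.
result = permutation_importance(clf, X_test, y_test,
                                n_repeats=5, random_state=42, n_jobs=-1)
for idx in result.importances_mean.argsort()[::-1][:10]:
    print(f"{X_test.columns[idx]}: {result.importances_mean[idx]:.4f}")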
from supervised.automl import AutoML

# mljar-supervised AutoML: try several algorithms and explain the results
automl = AutoML(algorithms=['LightGBM', 'Xgboost', 'Random Forest', 'Neural Network'],
                train_ensemble=False, explain_level=2)
automl.fit(X_train_reduced, y_train)
AutoML directory: AutoML_4
The task is binary_classification with evaluation metric logloss
AutoML will use algorithms: ['LightGBM', 'Xgboost', 'Random Forest', 'Neural Network']
AutoML steps: ['simple_algorithms', 'default_algorithms']
Skip simple_algorithms because no parameters were generated.
* Step default_algorithms will try to check up to 4 models
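Once the fit finishes, the trained candidates can be compared and the best model scored on the hold-out set (a sketch based on the mljar-supervised API, where get_leaderboard() returns a DataFrame of the trained models):

# Compare the trained candidates, then score the hold-out set with the best one.
print(automl.get_leaderboard())
automl_pred = automl.predict(X_test)
print_acc_and_CR(y_test, automl_pred)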