Skip to content

Commit b71bde3

Browse files
committed
Overhaul logistic regression scripts; fix a number of bugs.
1 parent 26eda05 commit b71bde3

File tree

10 files changed

+71
-123
lines changed

10 files changed

+71
-123
lines changed

methods/matlab/LOGISTIC_REGRESSION.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function logistic_regression(cmd)
2626
X = csvread(regressorsFile{:});
2727

2828
if isempty(responsesFile)
29-
y = X(:,end);
29+
y = X(:,end) + 1; % We have to increment because labels must be positive.
3030
X = X(:,1:end-1);
3131
else
3232
y = csvread(responsesFile{:});
@@ -47,7 +47,7 @@ function logistic_regression(cmd)
4747
disp(sprintf('[INFO ] total_time: %fs', toc(total_time)))
4848

4949
if ~isempty(testFile)
50-
csvwrite('predictions.csv', idx);
50+
csvwrite('predictions.csv', idx - 1); % Subtract extra label bit.
5151
csvwrite('matlab_lr_probs.csv', predictions);
5252
end
5353

methods/matlab/logistic_regression.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def RunMetrics(self, options):
7979

8080
# If the dataset contains at least two files then the second file is the test
8181
# file. In this case we add this to the command line.
82-
if len(self.dataset) == 2:
82+
if len(self.dataset) >= 2:
8383
inputCmd = "-i " + self.dataset[0] + " -t " + self.dataset[1]
8484
else:
8585
inputCmd = "-i " + self.dataset[0]
@@ -111,11 +111,15 @@ def RunMetrics(self, options):
111111
truelabels = np.genfromtxt(self.dataset[2], delimiter = ',')
112112
metrics['Runtime'] = timer.total_time
113113
confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions)
114-
metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix)
115-
metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix)
116-
metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix)
117-
metrics['Recall'] = Metrics.AvgRecall(confusionMatrix)
118-
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions)
114+
115+
metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix)
116+
metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix)
117+
metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix)
118+
metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix)
119+
metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix)
120+
metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix)
121+
metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, predictions)
122+
metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions)
119123

120124
Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose)
121125

methods/milk/logistic_regression.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def RunLogisticRegressionMilk():
8080
self.model = self.BuildModel()
8181
with totalTimer:
8282
self.model = self.model.train(trainData, labels)
83+
if len(self.dataset) > 1:
84+
# We get back probabilities; cast these to classes.
85+
self.predictions = np.greater(self.model.apply(testData), 0.5)
8386
except Exception as e:
8487
return -1
8588

@@ -112,4 +115,19 @@ def RunMetrics(self, options):
112115

113116
# Datastructure to store the results.
114117
metrics = {'Runtime' : results}
118+
119+
if len(self.dataset) >= 3:
120+
truelabels = LoadDataset(self.dataset[2])
121+
122+
confusionMatrix = Metrics.ConfusionMatrix(truelabels, self.predictions)
123+
124+
metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix)
125+
metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix)
126+
metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix)
127+
metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix)
128+
metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix)
129+
metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix)
130+
metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, self.predictions)
131+
metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, self.predictions)
132+
115133
return metrics

methods/mlpack/logistic_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def OptionsToStr(self, options):
9898
optionsStr = "-e " + str(options.pop("epsilon"))
9999
if "max_iterations" in options:
100100
optionsStr = optionsStr + " -n " + str(options.pop("max_iterations"))
101-
if "optimizer" in options:
101+
if "algorithm" in options:
102102
optionsStr = optionsStr + " -O " + str(options.pop("optimizer"))
103103
if "step_size" in options:
104104
optionsStr = optionsStr + " -s " + str(options.pop("step_size"))

methods/scikit/logistic_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def RunLogisticRegressionScikit():
8585
# Use the last row of the training set as the responses.
8686
X, y = SplitTrainData(self.dataset)
8787
if "algorithm" in options:
88-
self.opts["algorithm"] = str(options.pop("algorithm"))
88+
self.opts["solver"] = str(options.pop("algorithm"))
8989
if "epsilon" in options:
9090
self.opts["epsilon"] = float(options.pop("epsilon"))
9191
if "max_iterations" in options:

methods/shogun/logistic_regression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, dataset, timeout=0, verbose=True):
5151
self.predictions = None
5252
self.z = 1
5353
self.model = None
54+
self.max_iter = None
5455

5556
'''
5657
Build the model for the Logistic Regression.
@@ -63,6 +64,8 @@ def BuildModel(self, data, responses):
6364
# Create and train the classifier.
6465
model = MulticlassLogisticRegression(self.z, RealFeatures(data.T),
6566
MulticlassLabels(responses))
67+
if self.max_iter is not None:
68+
model.set_max_iter(self.max_iter);
6669
model.train()
6770
return model
6871

@@ -87,6 +90,10 @@ def RunLogisticRegressionShogun():
8790
# Use the last row of the training set as the responses.
8891
X, y = SplitTrainData(self.dataset)
8992

93+
# Get the maximum number of iterations.
94+
if "max_iterations" in options:
95+
self.max_iter = int(options.pop("max_iterations"))
96+
9097
# Get the regularization value.
9198
if "lambda" in options:
9299
self.z = float(options.pop("lambda"))

methods/weka/logistic_regression.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ def __del__(self):
6969
def RunMetrics(self, options):
7070
Log.Info("Perform Logistic Regression.", self.verbose)
7171

72+
maxIterStr = ""
73+
if 'max_iterations' in options:
74+
maxIterStr = " -m " + str(options['max_iterations']) + " "
75+
options.pop('max_iterations')
76+
7277
if len(options) > 0:
7378
Log.Fatal("Unknown parameters: " + str(options))
7479
raise Exception("unknown parameters")
@@ -79,8 +84,8 @@ def RunMetrics(self, options):
7984

8085
# Split the command using shell-like syntax.
8186
cmd = shlex.split("java -classpath " + self.path + "/weka.jar" +
82-
":methods/weka" + " LOGISTICREGRESSION -t " + self.dataset[0] + " -T " +
83-
self.dataset[1])
87+
":methods/weka" + " LogisticRegression -t " + self.dataset[0] + " -T " +
88+
self.dataset[1] + maxIterStr)
8489

8590
# Run command with the necessary arguments and return its output as a byte
8691
# string. We have untrusted input so we disable all shell based features.
@@ -105,11 +110,14 @@ def RunMetrics(self, options):
105110
truelabels = np.genfromtxt(self.dataset[2], delimiter = ',')
106111
metrics['Runtime'] = timer.total_time
107112
confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions)
108-
metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix)
109-
metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix)
110-
metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix)
111-
metrics['Recall'] = Metrics.AvgRecall(confusionMatrix)
112-
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions)
113+
metrics['Avg Accuracy'] = Metrics.AverageAccuracy(confusionMatrix)
114+
metrics['MultiClass Precision'] = Metrics.AvgPrecision(confusionMatrix)
115+
metrics['MultiClass Recall'] = Metrics.AvgRecall(confusionMatrix)
116+
metrics['MultiClass FMeasure'] = Metrics.AvgFMeasure(confusionMatrix)
117+
metrics['MultiClass Lift'] = Metrics.LiftMultiClass(confusionMatrix)
118+
metrics['MultiClass MCC'] = Metrics.MCCMultiClass(confusionMatrix)
119+
metrics['MultiClass Information'] = Metrics.AvgMPIArray(confusionMatrix, truelabels, predictions)
120+
metrics['Simple MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions)
113121

114122
Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose)
115123

methods/weka/src/LOGISTICREGRESSION.java

Lines changed: 0 additions & 102 deletions
This file was deleted.

methods/weka/src/LogisticRegression.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.io.IOException;
99
import weka.core.*;
1010
import weka.core.converters.ConverterUtils.DataSource;
11+
import weka.core.converters.CSVLoader;
1112
import weka.filters.Filter;
1213
import weka.filters.unsupervised.attribute.NumericToNominal;
1314

@@ -29,7 +30,8 @@ public class LogisticRegression {
2930
+ " the last row of the input file.\n\n"
3031
+ "Options:\n\n"
3132
+ "-t [string] Optional file containing containing\n"
32-
+ " test dataset");
33+
+ " test dataset\n"
34+
+ "-m [int] Maximum number of iterations\n");
3335

3436
public static HashMap<Integer, Double> createClassMap(Instances Data) {
3537
HashMap<Integer, Double> classMap = new HashMap<Integer, Double>();
@@ -69,6 +71,8 @@ public static void main(String args[]) {
6971

7072
// Load input dataset.
7173
DataSource source = new DataSource(regressorsFile);
74+
if (source.getLoader() instanceof CSVLoader)
75+
((CSVLoader) source.getLoader()).setNoHeaderRowPresent(true);
7276
Instances data = source.getDataSet();
7377

7478
// Transform numeric class to nominal class because the
@@ -81,12 +85,19 @@ public static void main(String args[]) {
8185
nm.setInputFormat(data);
8286
data = Filter.useFilter(data, nm);
8387

88+
boolean hasMaxIters = false;
89+
int maxIter = Integer.parseInt(Utils.getOption('m', args));
90+
if (maxIter != 0)
91+
hasMaxIters = true;
92+
8493
// Did the user pass a test file?
8594
String testFile = Utils.getOption('t', args);
8695
Instances testData = null;
8796
if (testFile.length() != 0)
8897
{
8998
source = new DataSource(testFile);
99+
if (source.getLoader() instanceof CSVLoader)
100+
((CSVLoader) source.getLoader()).setNoHeaderRowPresent(true);
90101
testData = source.getDataSet();
91102

92103
// Weka makes the assumption that the structure of the training and test
@@ -122,6 +133,8 @@ public static void main(String args[]) {
122133
// Perform Logistic Regression.
123134
timer.StartTimer("total_time");
124135
weka.classifiers.functions.Logistic model = new weka.classifiers.functions.Logistic();
136+
if (hasMaxIters)
137+
model.setMaxIts(maxIter);
125138
model.buildClassifier(data);
126139

127140
// Use the test data to evaluate the model.
@@ -140,7 +153,7 @@ public static void main(String args[]) {
140153
}
141154
FileWriter writer = new FileWriter(probabs.getName(), false);
142155

143-
File predictions = new File("weka_lr_predictions.csv");
156+
File predictions = new File("weka_predicted.csv");
144157
if(!predictions.exists()) {
145158
predictions.createNewFile();
146159
}

util/timer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def timeout(fun, timeout=9000):
6464
p.join()
6565

6666
Log.Warn("Script timed out after " + str(timeout) + " seconds")
67-
return -2
67+
return [-2]
6868
else:
6969
try:
7070
r = q.get(timeout=3)
7171
except Exception as e:
72-
r = -1
72+
r = [-1]
7373
return r

0 commit comments

Comments
 (0)