diff --git a/package-lock.json b/package-lock.json index 77113e7..397e2f9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,6 +10,7 @@ "license": "MIT", "dependencies": { "@tensorflow/tfjs-node": "^4.1.0", + "@tensorflow/tfjs-node-gpu": "^4.1.0", "tslib": "^2.4.1" }, "devDependencies": { @@ -2059,6 +2060,36 @@ "node": ">=8.11.0" } }, + "node_modules/@tensorflow/tfjs-node-gpu": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node-gpu/-/tfjs-node-gpu-4.1.0.tgz", + "integrity": "sha512-7lPoiz7m3Cvz1uxr+YGwGn6Ko1kDmsN9VZYVicZXzhshU7hDVoTBPI6DAS2na5nK9rCKBqQF/TQeQy9sz95BPg==", + "hasInstallScript": true, + "dependencies": { + "@mapbox/node-pre-gyp": "1.0.9", + "@tensorflow/tfjs": "4.1.0", + "adm-zip": "^0.5.2", + "google-protobuf": "^3.9.2", + "https-proxy-agent": "^2.2.1", + "progress": "^2.0.0", + "rimraf": "^2.6.2", + "tar": "^4.4.6" + }, + "engines": { + "node": ">=8.11.0" + } + }, + "node_modules/@tensorflow/tfjs-node-gpu/node_modules/rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "dependencies": { + "glob": "^7.1.3" + }, + "bin": { + "rimraf": "bin.js" + } + }, "node_modules/@tensorflow/tfjs-node/node_modules/rimraf": { "version": "2.7.1", "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", @@ -13519,6 +13550,31 @@ } } }, + "@tensorflow/tfjs-node-gpu": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/@tensorflow/tfjs-node-gpu/-/tfjs-node-gpu-4.1.0.tgz", + "integrity": "sha512-7lPoiz7m3Cvz1uxr+YGwGn6Ko1kDmsN9VZYVicZXzhshU7hDVoTBPI6DAS2na5nK9rCKBqQF/TQeQy9sz95BPg==", + "requires": { + "@mapbox/node-pre-gyp": "1.0.9", + "@tensorflow/tfjs": "4.1.0", + "adm-zip": "^0.5.2", + "google-protobuf": "^3.9.2", + "https-proxy-agent": "^2.2.1", + "progress": "^2.0.0", + "rimraf": "^2.6.2", + "tar": "^4.4.6" + }, + "dependencies": { + "rimraf": { + "version": "2.7.1", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-2.7.1.tgz", + "integrity": "sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==", + "requires": { + "glob": "^7.1.3" + } + } + } + }, "@trysound/sax": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/@trysound/sax/-/sax-0.2.0.tgz", diff --git a/package.json b/package.json index 0f2b518..bc0c318 100644 --- a/package.json +++ b/package.json @@ -6,6 +6,7 @@ "private": true, "dependencies": { "@tensorflow/tfjs-node": "^4.1.0", + "@tensorflow/tfjs-node-gpu": "^4.1.0", "tslib": "^2.4.1" }, "devDependencies": { @@ -30,4 +31,3 @@ "typescript": "4.8.4" } } - diff --git a/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts b/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts index bc7557a..ba3aa54 100644 --- a/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts +++ b/packages/tfjs-node-helpers/src/classification/binary-classification-trainer.ts @@ -22,8 +22,10 @@ import { Metric } from '../testing/metric'; import { MetricCalculator } from '../testing/metric-calculator'; import { binarize } from '../utils/binarize'; import { FeatureNormalizer } from '../feature-engineering/feature-normalizer'; +import { switchHardwareUsage } from '../utils/switch-hardware-usage'; export type BinaryClassificationTrainerOptions = { + shouldUseGPU?: boolean; batchSize?: number; epochs?: number; patience?: number; @@ -53,6 +55,8 @@ export class BinaryClassificationTrainer { protected static DEFAULT_PATIENCE: number = 20; constructor(options: BinaryClassificationTrainerOptions) { + switchHardwareUsage(options.shouldUseGPU); + this.batchSize = options.batchSize ?? BinaryClassificationTrainer.DEFAULT_BATCH_SIZE; this.epochs = options.epochs ?? BinaryClassificationTrainer.DEFAULT_EPOCHS; this.patience = options.patience ?? BinaryClassificationTrainer.DEFAULT_PATIENCE; diff --git a/packages/tfjs-node-helpers/src/classification/binary-classifier.ts b/packages/tfjs-node-helpers/src/classification/binary-classifier.ts index 111a374..8286393 100644 --- a/packages/tfjs-node-helpers/src/classification/binary-classifier.ts +++ b/packages/tfjs-node-helpers/src/classification/binary-classifier.ts @@ -1,13 +1,17 @@ import { LayersModel, loadLayersModel, tensor, Tensor } from '@tensorflow/tfjs-node'; +import { switchHardwareUsage } from '../utils/switch-hardware-usage'; export type BinaryClassifierOptions = { model: LayersModel; + shouldUseGPU?: boolean; }; export class BinaryClassifier { protected model?: LayersModel; constructor(options?: BinaryClassifierOptions) { + switchHardwareUsage(options?.shouldUseGPU); + this.model = options?.model; } diff --git a/packages/tfjs-node-helpers/src/utils/switch-hardware-usage.ts b/packages/tfjs-node-helpers/src/utils/switch-hardware-usage.ts new file mode 100644 index 0000000..db19322 --- /dev/null +++ b/packages/tfjs-node-helpers/src/utils/switch-hardware-usage.ts @@ -0,0 +1,9 @@ +export const switchHardwareUsage = (shouldUseGPU?: boolean): void => { + if (shouldUseGPU) { + console.log('Attempting to use GPU.'); + require('@tensorflow/tfjs-node-gpu'); + } else { + console.log('Using CPU.'); + require('@tensorflow/tfjs-node'); + } +};