176 lines
No EOL
8 KiB
JavaScript
176 lines
No EOL
8 KiB
JavaScript
import { __awaiter, __generator } from "tslib";
|
|
import * as tf from '@tensorflow/tfjs-core';
|
|
import { getModelUris } from './common/getModelUris';
|
|
import { loadWeightMap } from './dom';
|
|
import { env } from './env';
|
|
/**
 * Base class for networks whose parameters live in a (possibly nested) object
 * of tf.Tensor leaves (`params`). Each entry of `paramMappings` identifies one
 * such tensor by a '/'-separated property path into `params`.
 *
 * Not defined in this class and presumably supplied by subclasses:
 *   - extractParams(weights)                -> { params, paramMappings }
 *   - extractParamsFromWeigthMap(weightMap) -> { params, paramMappings }
 *   - getDefaultModelName()                 -> model name passed to the loaders
 */
var NeuralNetwork = /** @class */ (function () {
    /**
     * Returns an ArrayBuffer containing exactly the bytes viewed by `buf`.
     *
     * `buf.buffer` is the whole backing store of a typed-array view. A Node.js
     * Buffer may be a view into a larger shared pool (non-zero byteOffset or a
     * shorter byteLength), in which case passing `buf.buffer` on unchanged would
     * hand unrelated bytes to the weights loader.
     *
     * @param {{ buffer: ArrayBuffer, byteOffset?: number, byteLength?: number }} buf
     * @returns {ArrayBuffer}
     */
    function toArrayBuffer(buf) {
        var arrayBuffer = buf.buffer;
        if (typeof buf.byteOffset !== 'number' || typeof buf.byteLength !== 'number') {
            // No view information available: keep the previous behavior.
            return arrayBuffer;
        }
        if (buf.byteOffset === 0 && buf.byteLength === arrayBuffer.byteLength) {
            // The view spans the whole buffer, so no copy is needed.
            return arrayBuffer;
        }
        return arrayBuffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength);
    }
    /**
     * @param {string} _name - network name, used as a prefix in loader error messages.
     */
    function NeuralNetwork(_name) {
        this._name = _name;
        this._params = undefined;
        this._paramMappings = [];
    }
    Object.defineProperty(NeuralNetwork.prototype, "params", {
        /** The loaded parameter object, or undefined when nothing is loaded. */
        get: function () { return this._params; },
        enumerable: true,
        configurable: true
    });
    Object.defineProperty(NeuralNetwork.prototype, "paramMappings", {
        /** Array of `{ paramPath, ... }` entries naming each parameter tensor. */
        get: function () { return this._paramMappings; },
        enumerable: true,
        configurable: true
    });
    Object.defineProperty(NeuralNetwork.prototype, "isLoaded", {
        /** True when `params` is set (truthy). */
        get: function () { return !!this.params; },
        enumerable: true,
        configurable: true
    });
    /**
     * Returns the tensor stored at `paramPath`.
     * @param {string} paramPath - '/'-separated path into `params`.
     * @throws {Error} see traversePropertyPath.
     */
    NeuralNetwork.prototype.getParamFromPath = function (paramPath) {
        var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
        return obj[objProp];
    };
    /**
     * Replaces the tensor at `paramPath` with `tensor`, disposing the tensor
     * previously stored there.
     * @param {string} paramPath - '/'-separated path into `params`.
     * @param {tf.Tensor} tensor - the new value.
     * @throws {Error} see traversePropertyPath.
     */
    NeuralNetwork.prototype.reassignParamFromPath = function (paramPath, tensor) {
        var _a = this.traversePropertyPath(paramPath), obj = _a.obj, objProp = _a.objProp;
        // Re-assigning the tensor that is already in place must not dispose it.
        if (obj[objProp] !== tensor) {
            obj[objProp].dispose();
        }
        obj[objProp] = tensor;
    };
    /**
     * Resolves every entry of `paramMappings` to its tensor.
     * @returns {{ path: string, tensor: tf.Tensor }[]} in `paramMappings` order.
     */
    NeuralNetwork.prototype.getParamList = function () {
        var _this = this;
        return this._paramMappings.map(function (_a) {
            var paramPath = _a.paramPath;
            return ({
                path: paramPath,
                tensor: _this.getParamFromPath(paramPath)
            });
        });
    };
    /** Params whose tensor is a tf.Variable. */
    NeuralNetwork.prototype.getTrainableParams = function () {
        return this.getParamList().filter(function (param) { return param.tensor instanceof tf.Variable; });
    };
    /** Params whose tensor is not a tf.Variable. */
    NeuralNetwork.prototype.getFrozenParams = function () {
        return this.getParamList().filter(function (param) { return !(param.tensor instanceof tf.Variable); });
    };
    /**
     * Converts every frozen param into a tf.Variable (via `tensor.variable()`);
     * the replaced tensors are disposed by reassignParamFromPath.
     */
    NeuralNetwork.prototype.variable = function () {
        var _this = this;
        this.getFrozenParams().forEach(function (_a) {
            var path = _a.path, tensor = _a.tensor;
            _this.reassignParamFromPath(path, tensor.variable());
        });
    };
    /**
     * Converts every trainable (tf.Variable) param into a plain tensor holding
     * the same values, shape and dtype; the variables are disposed.
     */
    NeuralNetwork.prototype.freeze = function () {
        var _this = this;
        this.getTrainableParams().forEach(function (_a) {
            var path = _a.path, variable = _a.tensor;
            // Pass shape and dtype explicitly: from a bare TypedArray, tf.tensor
            // would infer a flat 1-D shape.
            var tensor = tf.tensor(variable.dataSync(), variable.shape, variable.dtype);
            // reassignParamFromPath disposes the variable it replaces; disposing it
            // here as well would dispose it twice.
            _this.reassignParamFromPath(path, tensor);
        });
    };
    /**
     * Disposes every param tensor and unloads the network (`params` becomes undefined).
     * @param {boolean} [throwOnRedispose=true] - when true, throw if a param tensor
     *   is already disposed; when false, silently skip such tensors.
     * @throws {Error} if throwOnRedispose and a tensor was already disposed, or if
     *   there are param mappings but no loaded params (see traversePropertyPath).
     */
    NeuralNetwork.prototype.dispose = function (throwOnRedispose) {
        if (throwOnRedispose === void 0) { throwOnRedispose = true; }
        this.getParamList().forEach(function (param) {
            if (param.tensor.isDisposed) {
                if (throwOnRedispose) {
                    throw new Error("param tensor has already been disposed for path " + param.path);
                }
                // Skip: disposing a tf.Variable a second time may throw in tfjs-core.
                return;
            }
            param.tensor.dispose();
        });
        this._params = undefined;
    };
    /**
     * Flattens all param tensors, in `paramMappings` order, into one Float32Array.
     * Returns an empty Float32Array when there are no params.
     * @returns {Float32Array}
     */
    NeuralNetwork.prototype.serializeParams = function () {
        var chunks = this.getParamList().map(function (_a) { return _a.tensor.dataSync(); });
        var totalLength = chunks.reduce(function (sum, chunk) { return sum + chunk.length; }, 0);
        var flat = new Float32Array(totalLength);
        var offset = 0;
        chunks.forEach(function (chunk) {
            flat.set(chunk, offset);
            offset += chunk.length;
        });
        return flat;
    };
    /**
     * Loads params either from a flat Float32Array (via extractWeights) or from
     * a model URI (via loadFromUri).
     * @param {Float32Array|string|undefined} weightsOrUrl
     * @returns {Promise<void>}
     */
    NeuralNetwork.prototype.load = function (weightsOrUrl) {
        return __awaiter(this, void 0, void 0, function () {
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (weightsOrUrl instanceof Float32Array) {
                            this.extractWeights(weightsOrUrl);
                            return [2 /*return*/];
                        }
                        return [4 /*yield*/, this.loadFromUri(weightsOrUrl)];
                    case 1:
                        _a.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Fetches a weight map with loadWeightMap(uri, getDefaultModelName()) and
     * loads it via loadFromWeightMap.
     * @param {string|undefined} uri - a falsy uri is passed through unchanged
     *   (presumably meaning a default location; confirm in loadWeightMap).
     * @returns {Promise<void>}
     * @throws {Error} if `uri` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromUri = function (uri) {
        return __awaiter(this, void 0, void 0, function () {
            var weightMap;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (uri && typeof uri !== 'string') {
                            throw new Error(this._name + ".loadFromUri - expected model uri");
                        }
                        return [4 /*yield*/, loadWeightMap(uri, this.getDefaultModelName())];
                    case 1:
                        weightMap = _a.sent();
                        this.loadFromWeightMap(weightMap);
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Loads the model from the file system using the environment's readFile:
     * parses the JSON manifest at the resolved manifest path, then loads the
     * weight shards relative to the model base path with
     * tf.io.weightsLoaderFactory.
     * @param {string|undefined} filePath - model path; a falsy value is passed
     *   through to getModelUris.
     * @returns {Promise<void>}
     * @throws {Error} if `filePath` is truthy but not a string.
     */
    NeuralNetwork.prototype.loadFromDisk = function (filePath) {
        return __awaiter(this, void 0, void 0, function () {
            var readFile, _a, manifestUri, modelBaseUri, fetchWeightsFromDisk, loadWeights, manifest, _b, _c, weightMap;
            return __generator(this, function (_d) {
                switch (_d.label) {
                    case 0:
                        if (filePath && typeof filePath !== 'string') {
                            throw new Error(this._name + ".loadFromDisk - expected model file path");
                        }
                        readFile = env.getEnv().readFile;
                        _a = getModelUris(filePath, this.getDefaultModelName()), manifestUri = _a.manifestUri, modelBaseUri = _a.modelBaseUri;
                        fetchWeightsFromDisk = function (filePaths) {
                            return Promise.all(filePaths.map(function (shardPath) {
                                // Copy out exactly the bytes of each shard (see toArrayBuffer).
                                return readFile(shardPath).then(toArrayBuffer);
                            }));
                        };
                        loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk);
                        _c = (_b = JSON).parse;
                        return [4 /*yield*/, readFile(manifestUri)];
                    case 1:
                        manifest = _c.apply(_b, [(_d.sent()).toString()]);
                        return [4 /*yield*/, loadWeights(manifest, modelBaseUri)];
                    case 2:
                        weightMap = _d.sent();
                        this.loadFromWeightMap(weightMap);
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Sets params and paramMappings from a weight map via the subclass hook
     * extractParamsFromWeigthMap (name spelled as the hook is called).
     */
    NeuralNetwork.prototype.loadFromWeightMap = function (weightMap) {
        var _a = this.extractParamsFromWeigthMap(weightMap), paramMappings = _a.paramMappings, params = _a.params;
        this._paramMappings = paramMappings;
        this._params = params;
    };
    /**
     * Sets params and paramMappings from a flat weight array via the subclass
     * hook extractParams.
     * @param {Float32Array} weights
     */
    NeuralNetwork.prototype.extractWeights = function (weights) {
        var _a = this.extractParams(weights), paramMappings = _a.paramMappings, params = _a.params;
        this._paramMappings = paramMappings;
        this._params = params;
    };
    /**
     * Walks `paramPath` ('/'-separated) through `params` and returns the parent
     * object plus the final property name of the addressed tensor.
     * @param {string} paramPath
     * @returns {{ obj: Object, objProp: string }}
     * @throws {Error} if no params are loaded, if any path segment is not an own
     *   property of the current (non-null) object, or if the leaf is not a tf.Tensor.
     */
    NeuralNetwork.prototype.traversePropertyPath = function (paramPath) {
        if (!this.params) {
            throw new Error("traversePropertyPath - model has no loaded params");
        }
        var hasOwn = Object.prototype.hasOwnProperty;
        var result = paramPath.split('/').reduce(function (res, objProp) {
            var current = res.nextObj;
            // Guard against null/undefined intermediates and prototype-less objects,
            // which would otherwise surface as an unrelated TypeError.
            if (current === null || current === undefined || !hasOwn.call(current, objProp)) {
                throw new Error("traversePropertyPath - object does not have property " + objProp + ", for path " + paramPath);
            }
            return { obj: current, objProp: objProp, nextObj: current[objProp] };
        }, { nextObj: this.params });
        var obj = result.obj, objProp = result.objProp;
        if (!obj || !objProp || !(obj[objProp] instanceof tf.Tensor)) {
            throw new Error("traversePropertyPath - parameter is not a tensor, for path " + paramPath);
        }
        return { obj: obj, objProp: objProp };
    };
    return NeuralNetwork;
}());

export { NeuralNetwork };
|
|
//# sourceMappingURL=NeuralNetwork.js.map
|