Skip to content

Commit

Permalink
feat: global optimization by direct algorithm (#103)
Browse files Browse the repository at this point in the history
* feat: add DIRECT optimizer and keep id

* fix: wrong conditional. close: #102

close: #98
  • Loading branch information
jobo322 committed Sep 23, 2022
1 parent 2621110 commit fb27ade
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 18 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"dependencies": {
"cheminfo-types": "^1.1.0",
"ml-array-max": "^1.2.4",
"ml-direct": "^0.1.1",
"ml-levenberg-marquardt": "^4.1.0",
"ml-peak-shape-generator": "^4.1.2",
"ml-spectra-processing": "^11.5.0"
Expand Down
68 changes: 68 additions & 0 deletions src/__tests__/globalOptimization.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import type { DataXY } from 'cheminfo-types';
import { toBeDeepCloseTo, toMatchCloseTo } from 'jest-matcher-deep-close-to';
import { generateSpectrum } from 'spectrum-generator';

import { optimize } from '../index';

expect.extend({ toBeDeepCloseTo, toMatchCloseTo });

describe('Optimize sum of Gaussians', () => {
const peaks = [
{ x: -0.5, y: 1, shape: { kind: 'gaussian' as const, fwhm: 0.05 } },
{ x: 0.5, y: 1, shape: { kind: 'gaussian' as const, fwhm: 0.05 } },
];

const data: DataXY = generateSpectrum(peaks, {
generator: {
from: -1,
to: 1,
nbPoints: 1024,
shape: { kind: 'gaussian' },
},
});

let result = optimize(
data,
[
{
x: -0.55,
y: 0.9,
shape: { kind: 'gaussian' as const, fwhm: 0.08 },
parameters: {
x: { min: -0.49, max: -0.512 },
y: { min: 0.9, max: 1.2 },
fwhm: { min: 0.04, max: 0.07 },
},
},
{
x: 0.55,
y: 0.9,
shape: { kind: 'gaussian' as const, fwhm: 0.08 },
parameters: {
x: { min: 0.49, max: 0.512 },
y: { min: 0.9, max: 1.2 },
fwhm: { min: 0.04, max: 0.07 },
},
},
],
{
optimization: {
kind: 'direct',
options: {
maxIterations: 20,
},
},
},
);
for (let i = 0; i < 2; i++) {
const peak = peaks[i];
for (const key in peak) {
//@ts-expect-error
const value = peak[key];
it(`peak at ${peak.x} key: ${key}`, () => {
//@ts-expect-error
expect(result.peaks[i][key]).toMatchCloseTo(value, 2);
});
}
}
});
45 changes: 28 additions & 17 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,40 @@ type OptimizedPeakIDOrNot<T extends Peak> = T extends { id: string }

type OptimizationParameter = number | ((peak: Peak) => number);

interface GeneralAlgorithmOptions {
/** number of max iterations
* @default 100
*/
maxIterations?: number;
}
export interface LMOptimizationOptions extends GeneralAlgorithmOptions {
/** maximum time running before break in seconds */
timeout?: number;
/** damping factor
* @default 1.5
*/
damping?: number;
/** error tolerance
* @default 1e-8
*/
errorTolerance?: number;
}

export interface DirectOptimizationOptions extends GeneralAlgorithmOptions {
epsilon?: number;
tolerance?: number;
tolerance2?: number;
initialState?: any;
}

export interface OptimizationOptions {
/**
* kind of algorithm. By default it's levenberg-marquardt
*/
kind?: 'lm' | 'levenbergMarquardt';
kind?: 'lm' | 'levenbergMarquardt' | 'direct';

/** options for the specific kind of algorithm */
options?: {
/** maximum time running before break in seconds */
timeout?: number;
/** damping factor
* @default 1.5
*/
damping?: number;
/** number of max iterations
* @default 100
*/
maxIterations?: number;
/** error tolerance
* @default 1e-8
*/
errorTolerance?: number;
};
options?: DirectOptimizationOptions | LMOptimizationOptions;
}

export interface OptimizeOptions {
Expand Down
2 changes: 1 addition & 1 deletion src/shapes/getSumOfShapes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export function getSumOfShapes(internalPeaks: InternalPeak[]) {
for (const peak of internalPeaks) {
const peakX = parameters[peak.fromIndex];
const y = parameters[peak.fromIndex + 1];
for (let i = 2; i <= peak.toIndex; i++) {
for (let i = 2; i < parameters.length; i++) {
//@ts-expect-error Not simply to solve the issue
peak.shapeFct[peak.parameters[i]] = parameters[peak.fromIndex + i];
}
Expand Down
15 changes: 15 additions & 0 deletions src/util/selectMethod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { levenbergMarquardt } from 'ml-levenberg-marquardt';

import { OptimizationOptions } from '../index';

import { directOptimization } from './wrappers/directOptimization';

/** Algorithm to select the method.
* @param optimizationOptions - Optimization options
* @returns - The algorithm and optimization options
Expand All @@ -21,6 +23,19 @@ export function selectMethod(optimizationOptions: OptimizationOptions = {}) {
...options,
},
};
case 'direct': {
return {
algorithm: directOptimization,
optimizationOptions: {
iterations: 20,
epsilon: 1e-4,
tolerance: 1e-16,
tolerance2: 1e-12,
initialState: {},
...options,
},
};
}
default:
throw new Error(`Unknown fitting algorithm`);
}
Expand Down
50 changes: 50 additions & 0 deletions src/util/wrappers/directOptimization.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { DataXY } from 'cheminfo-types';
import direct from 'ml-direct';

export function directOptimization(
data: DataXY,
sumOfShapes: (parameters: number[]) => (x: number) => number,
options: any,
) {
const {
minValues,
maxValues,
maxIterations,
epsilon,
tolerance,
tolerance2,
initialState,
} = options;
const objectiveFunction = getObjectiveFunction(data, sumOfShapes);
const result = direct(objectiveFunction, minValues, maxValues, {
iterations: maxIterations,
epsilon,
tolerance,
tolerance2,
initialState,
});

const { optima } = result;

return {
parameterError: result.minFunctionValue,
iterations: result.iterations,
parameterValues: optima[0],
};
}

function getObjectiveFunction(
data: DataXY,
sumOfShapes: (parameters: number[]) => (x: number) => number,
) {
const { x, y } = data;
const nbPoints = x.length;
return (parameters: number[]) => {
const fct = sumOfShapes(parameters);
let error = 0;
for (let i = 0; i < nbPoints; i++) {
error += Math.pow(y[i] - fct(x[i]), 2);
}
return error;
};
}

0 comments on commit fb27ade

Please sign in to comment.