{ "cells": [ { "cell_type": "markdown", "source": [ "# Example 2: SVM\n", "This example shows how to use our PLQ Composite Decomposition tool to decompose the SVM loss function.\n", "\n", "Given the loss function $L(\\beta) = \\frac{1}{2}\\|\\beta\\|^2 + C\\sum_{i=1}^n \\max(0, 1-y_i\\beta^Tx_i)$, we can decompose it into a PLQ composite function." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "from plqcom import PLQLoss, plq_to_rehloss, affine_transformation\n", "import numpy as np\n", "from rehline import ReHLine" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:33.530588Z", "start_time": "2025-01-03T09:25:31.686358Z" } }, "outputs": [], "execution_count": 1 }, { "cell_type": "markdown", "source": [ "# 1. Data Generation\n", "First, we generate a classification dataset with 1000 samples and 3 features:\n", "$\\mathbf{Y}=\\operatorname{sgn}(\\mathbf{X}\\mathbf{\\beta} + \\mathbf{\\epsilon})$, where $\\mathbf{X} \\in \\mathbb{R}^{1000 \\times 3}$, $\\mathbf{Y} \\in \\mathbb{R}^{1000}$, $\\mathbf{\\beta} \\in \\mathbb{R}^{3}$, and $\\mathbf{\\epsilon}$ is standard Gaussian noise.\n", "The true parameter $\\mathbf{\\beta}$ is randomly generated." 
], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "n, d, C = 1000, 3, 0.5\n", "np.random.seed(1024)\n", "X = np.random.randn(n, d)\n", "beta = np.random.randn(d)\n", "y = np.sign(X.dot(beta) + np.random.randn(n))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:33.540980Z", "start_time": "2025-01-03T09:25:33.538303Z" } }, "outputs": [], "execution_count": 2 }, { "cell_type": "markdown", "source": [ "Check the first 10 samples." ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "X_head = X[:10]\n", "y_head = y[:10]\n", "X_head, y_head" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:33.551738Z", "start_time": "2025-01-03T09:25:33.548375Z" } }, "outputs": [ { "data": { "text/plain": [ "(array([[ 2.12444863, 0.25264613, 1.45417876],\n", " [ 0.56923979, 0.45822365, -0.80933344],\n", " [ 0.86407349, 0.20170137, -1.87529904],\n", " [-0.56850693, -0.06510141, 0.80681666],\n", " [-0.5778176 , 0.57306064, -0.33667496],\n", " [ 0.29700734, -0.37480416, 0.15510474],\n", " [ 0.70485719, 0.8452178 , -0.65818079],\n", " [ 0.56810558, 0.51538125, -0.61564998],\n", " [ 0.92611427, -1.28591285, 1.43014026],\n", " [-0.4254975 , -0.40257712, 0.60410409]]),\n", " array([ 1., -1., -1., -1., -1., 1., -1., -1., 1., 1.]))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 3 }, { "cell_type": "markdown", "source": [ "# 2. 
Create and Decompose the PLQ Loss" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "# Create a PLQLoss object for the hinge function max(0, z): 0 for z <= 0 and z for z > 0\n", "plqloss = PLQLoss(quad_coef={'a': np.array([0., 0.]), 'b': np.array([0., 1.]), 'c': np.array([0., 0.])},\n", " cutpoints=np.array([0]))\n", "# Decompose the PLQ loss into a PLQ composite loss\n", 
"rehloss = plq_to_rehloss(plqloss)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:36.557759Z", "start_time": "2025-01-03T09:25:36.553207Z" } }, "outputs": [], "execution_count": 4 }, { "cell_type": "code", "source": [ "print(rehloss.relu_coef, rehloss.relu_intercept)\n", "print(rehloss.rehu_cut, rehloss.rehu_coef, rehloss.rehu_intercept)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:37.193533Z", "start_time": "2025-01-03T09:25:37.190454Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1.]] [[-0.]]\n", "[] [] []\n" ] } ], "execution_count": 5 }, { "cell_type": "markdown", "source": [ "# 3. Broadcast to all samples" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "# Broadcast the PLQ composite loss to all samples\n", "rehloss = affine_transformation(rehloss, n=X.shape[0], c=C, p=-y, q=1)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:46.140308Z", "start_time": "2025-01-03T09:25:46.136811Z" } }, "outputs": [], "execution_count": 6 }, { "cell_type": "code", "source": [ "print(rehloss.relu_coef.shape)\n", "print(\"First ten sample relu coefficients: %s\" % rehloss.relu_coef[0][:10])\n", "print(\"First ten sample relu intercepts: %s\" % rehloss.relu_intercept[0][:10])" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:25:47.351946Z", "start_time": "2025-01-03T09:25:47.349011Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 1000)\n", "First ten sample relu coefficients: [-0.5 0.5 0.5 0.5 0.5 -0.5 0.5 0.5 -0.5 -0.5]\n", "First ten sample relu intercepts: [0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]\n" ] } ], "execution_count": 7 }, { "cell_type": "markdown", "source": [ "# 4. 
Use ReHLine to solve the problem" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "source": [ "clf = ReHLine()\n", "# Pass the per-sample ReLU coefficients and intercepts to the solver\n", "clf.U, clf.V = rehloss.relu_coef, rehloss.relu_intercept\n", "clf.fit(X=X)\n", "print('sol provided by rehline: %s' % clf.coef_)\n", "print(clf.decision_function([[.1, .2, .3]]))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" }, "ExecuteTime": { "end_time": "2025-01-03T09:26:05.820895Z", "start_time": "2025-01-03T09:26:05.816607Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sol provided by rehline: [ 0.74100049 -0.00617326 2.66988444]\n", "[0.87383073]\n" ] } ], "execution_count": 8 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }