当前位置:首页 > 文章列表 > Golang > Go问答 > 执行梯度降低

执行梯度降低

来源:stackoverflow 2024-03-04 19:15:28 0浏览 收藏

“纵有疾风来,人生不言弃”,这句话送给正在学习Golang的朋友们,也希望在阅读本文《执行梯度降低》后,能够真的帮助到大家。我也会在后续的文章中,陆续更新Golang相关的技术文章,有好的建议欢迎大家在评论留言,非常感谢!

问题内容

我正在尝试在 go 中实现梯度下降。我的目标是根据汽车的行驶里程来预测汽车的成本。 这是我的数据集:

km,price
240000,3650
139800,3800
150500,4400
185530,4450
176000,5250
114800,5350
166800,5800
89000,5990
144500,5999
84000,6200
82029,6390
63060,6390
74000,6600
97500,6800
67000,6800
76025,6900
48235,6900
93000,6990
60949,7490
65674,7555
54000,7990
68500,7990
22899,7990
61789,8290

我尝试了各种方法,例如规范化数据集、不规范化数据集、保留 thetas 不变、非规范化 thetas...但我无法得到正确的结果。 我的数学一定有什么地方不对劲,但我不知道哪里不对。 我想要得到的结果应该约为 t0 = 8500, t1 = -0.02 我的实现如下:

package main

import (
    "encoding/csv"
    "fmt"
    "log"
    "math"
    "os"
    "strconv"
)

const (
    dataFile     = "data.csv"
    iterations   = 20000
    learningRate = 0.1
)

type dataSet [][]float64

var minKm, maxKm, minPrice, maxPrice float64

func (d dataSet) getExtremes(column int) (float64, float64) {

    min := math.Inf(1)
    max := math.Inf(-1)
    for _, row := range d {
        item := row[column]
        if item > max {
            max = item
        }
        if item < min {
            min = item
        }
    }

    return min, max
}

func normalizeItem(item, min, max float64) float64 {

    return (item - min) / (max - min)
}

func (d *dataSet) normalize() {

    minKm, maxKm = d.getExtremes(0)
    minPrice, maxPrice = d.getExtremes(1)
    for _, row := range *d {
        row[0], row[1] = normalizeItem(row[0], minKm, maxKm), normalizeItem(row[1], minPrice, maxPrice)
    }
}

func processEntry(entry []string) []float64 {

    if len(entry) != 2 {
        log.Fatalln("expected two fields")
    }
    km, err := strconv.ParseFloat(entry[0], 64)
    if err != nil {
        log.Fatalln(err)
    }
    price, err := strconv.ParseFloat(entry[1], 64)
    if err != nil {
        log.Fatalln(err)
    }
    return []float64{km, price}
}

func getData() dataSet {

    file, err := os.Open(dataFile)
    if err != nil {
        log.Fatalln(err)
    }
    reader := csv.NewReader(file)
    entries, err := reader.ReadAll()
    if err != nil {
        log.Fatalln(err)
    }
    entries = entries[1:]

    data := make(dataSet, len(entries))
    for k, entry := range entries {
        data[k] = processEntry(entry)
    }
    return data
}

func outputResult(theta0, theta1 float64) {
    file, err := os.OpenFile("weights.csv", os.O_WRONLY, 0644)
    if err != nil {
        log.Fatalln(err)
    }
    defer file.Close()
    file.Truncate(0)
    file.Seek(0, 0)
    file.WriteString(fmt.Sprintf("theta0,%.6f\ntheta1,%.6f\n", theta0, theta1))
}

func estimatePrice(theta0, theta1, mileage float64) float64 {

    return theta0 + theta1*mileage
}

func (d dataSet) computeThetas(theta0, theta1 float64) (float64, float64) {

    dataSize := float64(len(d))
    t0sum, t1sum := 0.0, 0.0
    for _, it := range d {
        mileage := it[0]
        price := it[1]
        err := estimatePrice(theta0, theta1, mileage) - price
        t0sum += err
        t1sum += err * mileage
    }

    return theta0 - (t0sum / dataSize * learningRate), theta1 - (t1sum / dataSize * learningRate)
}

func denormalize(theta, min, max float64) float64 {

    return theta*(max-min) + min
}

func main() {

    data := getData()
    data.normalize()
    theta0, theta1 := 0.0, 0.0
    for k := 0; k < iterations; k++ {
        theta0, theta1 = data.computeThetas(theta0, theta1)
    }
    theta0 = denormalize(theta0, minKm, maxKm)
    theta1 = denormalize(theta1, minPrice, maxPrice)
    outputResult(theta0, theta1)
}

为了正确实现梯度下降,我应该修复什么?


解决方案


Linear Regression 非常简单:

// yi = alpha + beta*xi + ei
func linearregression(x, y []float64) (float64, float64) {
    ex := expected(x)
    ey := expected(y)
    exy := expectedxy(x, y)
    exx := expectedxy(x, x)

    covariance := exy - ex*ey
    variance := exx - ex*ex
    beta := covariance / variance
    alpha := ey - beta*ex
    return alpha, beta
}

尝试一下 here,输出:

8499.599649933218 -0.021448963591702314 396270.87871142407

代码:

package main

import (
    "encoding/csv"
    "fmt"
    "strconv"
    "strings"
)

func main() {
    x, y := readxy(`data.csv`)
    alpha, beta := linearregression(x, y)
    fmt.println(alpha, beta, -alpha/beta) // 8499.599649933218 -0.021448963591702314 396270.87871142407
}

// https://en.wikipedia.org/wiki/ordinary_least_squares#simple_linear_regression_model
// yi = alpha + beta*xi + ei
func linearregression(x, y []float64) (float64, float64) {
    ex := expected(x)
    ey := expected(y)
    exy := expectedxy(x, y)
    exx := expectedxy(x, x)

    covariance := exy - ex*ey
    variance := exx - ex*ex
    beta := covariance / variance
    alpha := ey - beta*ex
    return alpha, beta
}

// e[x]
func expected(x []float64) float64 {
    sum := 0.0
    for _, v := range x {
        sum += v
    }
    return sum / float64(len(x))
}

// e[xy]
func expectedxy(x, y []float64) float64 {
    sum := 0.0
    for i, v := range x {
        sum += v * y[i]
    }
    return sum / float64(len(x))
}

func readxy(filename string) ([]float64, []float64) {
    // file, err := os.open(filename)
    // if err != nil {
    //  panic(err)
    // }
    // defer file.close()
    file := strings.newreader(data)

    reader := csv.newreader(file)
    records, err := reader.readall()
    if err != nil {
        panic(err)
    }
    records = records[1:]
    size := len(records)
    x := make([]float64, size)
    y := make([]float64, size)
    for i, v := range records {
        val, err := strconv.parsefloat(v[0], 64)
        if err != nil {
            panic(err)
        }
        x[i] = val
        val, err = strconv.parsefloat(v[1], 64)
        if err != nil {
            panic(err)
        }
        y[i] = val
    }
    return x, y
}

var data = `km,price
240000,3650
139800,3800
150500,4400
185530,4450
176000,5250
114800,5350
166800,5800
89000,5990
144500,5999
84000,6200
82029,6390
63060,6390
74000,6600
97500,6800
67000,6800
76025,6900
48235,6900
93000,6990
60949,7490
65674,7555
54000,7990
68500,7990
22899,7990
61789,8290`

Gradient descent 基于这样的观察:如果多变量函数 f(x) 被定义并且在点 a 的邻域中可微,那么如果从 azqba 沿fa,-∇f(a) 处的负梯度,例如:

// f(x)
f := func(x float64) float64 {
    return alpha + beta*x // write your target function here
}

导数函数:

h := 0.000001
// derivative function ∇f(x)
df := func(x float64) float64 {
    return (f(x+h) - f(x-h)) / (2 * h) // write your target function derivative here
}

搜索:

minimunAt := 1.0       // We start the search here
gamma := 0.01          // Step size multiplier
precision := 0.0000001 // Desired precision of result
max := 100000          // Maximum number of iterations
currentX := 0.0
step := 0.0
for i := 0; i < max; i++ {
    currentX = minimunAt
    minimunAt = currentX - gamma*df(currentX)
    step = minimunAt - currentX
    if math.Abs(step) <= precision {
        break
    }
}

fmt.Printf("Minimum at %.8f value: %v\n", minimunAt, f(minimunAt))

本篇关于《执行梯度降低》的介绍就到此结束啦,但是学无止境,想要了解学习更多关于Golang的相关知识,请关注golang学习网公众号!

版本声明
本文转载于:stackoverflow 如有侵犯,请联系study_golang@163.com删除
PHP7更新内容:解决未定义问题PHP7更新内容:解决未定义问题
上一篇
PHP7更新内容:解决未定义问题
在 GO 中获取重复匹配组
下一篇
在 GO 中获取重复匹配组
查看更多
最新文章
查看更多
课程推荐
  • 前端进阶之JavaScript设计模式
    前端进阶之JavaScript设计模式
    设计模式是开发人员在软件开发过程中面临一般问题时的解决方案,代表了最佳的实践。本课程的主打内容包括JS常见设计模式以及具体应用场景,打造一站式知识长龙服务,适合有JS基础的同学学习。
    542次学习
  • GO语言核心编程课程
    GO语言核心编程课程
    本课程采用真实案例,全面具体可落地,从理论到实践,一步一步将GO核心编程技术、编程思想、底层实现融会贯通,使学习者贴近时代脉搏,做IT互联网时代的弄潮儿。
    508次学习
  • 简单聊聊mysql8与网络通信
    简单聊聊mysql8与网络通信
    如有问题加微信:Le-studyg;在课程中,我们将首先介绍MySQL8的新特性,包括性能优化、安全增强、新数据类型等,帮助学生快速熟悉MySQL8的最新功能。接着,我们将深入解析MySQL的网络通信机制,包括协议、连接管理、数据传输等,让
    497次学习
  • JavaScript正则表达式基础与实战
    JavaScript正则表达式基础与实战
    在任何一门编程语言中,正则表达式,都是一项重要的知识,它提供了高效的字符串匹配与捕获机制,可以极大的简化程序设计。
    487次学习
  • 从零制作响应式网站—Grid布局
    从零制作响应式网站—Grid布局
    本系列教程将展示从零制作一个假想的网络科技公司官网,分为导航,轮播,关于我们,成功案例,服务流程,团队介绍,数据部分,公司动态,底部信息等内容区块。网站整体采用CSSGrid布局,支持响应式,有流畅过渡和展现动画。
    484次学习
查看更多
AI推荐
  • 可图AI图片生成:快手可灵AI2.0引领图像创作新时代
    可图AI图片生成
    探索快手旗下可灵AI2.0发布的可图AI2.0图像生成大模型,体验从文本生成图像、图像编辑到风格转绘的全链路创作。了解其技术突破、功能创新及在广告、影视、非遗等领域的应用,领先于Midjourney、DALL-E等竞品。
    14次使用
  • MeowTalk喵说:AI猫咪语言翻译,增进人猫情感交流
    MeowTalk喵说
    MeowTalk喵说是一款由Akvelon公司开发的AI应用,通过分析猫咪的叫声,帮助主人理解猫咪的需求和情感。支持iOS和Android平台,提供个性化翻译、情感互动、趣味对话等功能,增进人猫之间的情感联系。
    14次使用
  • SEO标题Traini:全球首创宠物AI技术,提升宠物健康与行为解读
    Traini
    SEO摘要Traini是一家专注于宠物健康教育的创新科技公司,利用先进的人工智能技术,提供宠物行为解读、个性化训练计划、在线课程、医疗辅助和个性化服务推荐等多功能服务。通过PEBI系统,Traini能够精准识别宠物狗的12种情绪状态,推动宠物与人类的智能互动,提升宠物生活质量。
    17次使用
  • 可图AI 2.0:快手旗下新一代图像生成大模型,专业创作者与普通用户的多模态创作引擎
    可图AI 2.0图片生成
    可图AI 2.0 是快手旗下的新一代图像生成大模型,支持文本生成图像、图像编辑、风格转绘等全链路创作需求。凭借DiT架构和MVL交互体系,提升了复杂语义理解和多模态交互能力,适用于广告、影视、非遗等领域,助力创作者高效创作。
    19次使用
  • 毕业宝AIGC检测:AI生成内容检测工具,助力学术诚信
    毕业宝AIGC检测
    毕业宝AIGC检测是“毕业宝”平台的AI生成内容检测工具,专为学术场景设计,帮助用户初步判断文本的原创性和AI参与度。通过与知网、维普数据库联动,提供全面检测结果,适用于学生、研究者、教育工作者及内容创作者。
    32次使用
微信登录更方便
  • 密码登录
  • 注册账号
登录即同意 用户协议隐私政策
返回登录
  • 重置密码