Machine learning (机器学习)
This commit is contained in:
parent 6af21ee8e4
commit 5966b77d66
@@ -10,7 +10,7 @@
    <modelVersion>4.0.0</modelVersion>

    <artifactId>start</artifactId>

    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
@@ -21,6 +21,26 @@
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.13</artifactId>
            <version>3.2.0</version>
            <exclusions>
                <exclusion>
                    <artifactId>janino</artifactId>
                    <groupId>org.codehaus.janino</groupId>
                </exclusion>
                <exclusion>
                    <artifactId>commons-compiler</artifactId>
                    <groupId>org.codehaus.janino</groupId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <artifactId>janino</artifactId>
            <groupId>org.codehaus.janino</groupId>
            <version>3.0.8</version>
        </dependency>
        <dependency>
            <artifactId>commons-compiler</artifactId>
            <groupId>org.codehaus.janino</groupId>
            <version>3.0.8</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
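Why the Janino exclusions: Spark SQL compiles generated Java code at runtime through Janino, and the janino and commons-compiler artifacts are released in lockstep, so mixing versions on the classpath commonly breaks whole-stage codegen with ClassNotFoundException or NoSuchMethodError. Excluding Spark's transitive copies and pinning both to 3.0.8 (presumably chosen to match another library already on this classpath) keeps the pair consistent. A quick sanity check in Scala (a sketch, not part of this commit):

    // Print the Janino version actually loaded from the classpath;
    // may print null if the jar's manifest lacks Implementation-Version.
    println(classOf[org.codehaus.janino.SimpleCompiler].getPackage.getImplementationVersion)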
@@ -3,7 +3,6 @@ package org.jeecg.sy.java;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.jeecg.common.api.vo.Result;
import org.jeecg.sy.scala.WeatherPrediction$;
import org.jeecg.sy.temp.analysis$;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.redis.core.RedisTemplate;
@@ -24,6 +23,7 @@ public class Controller {
    @Resource
    private RedisTemplate<String, Object> redisTemplate;

    // Web crawler endpoint
    @RequestMapping("/get")
    public Result test1() throws MalformedURLException {
        return Result.ok(analysis$.MODULE$.analysis2(getData.getDatas("")));
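The analysis$.MODULE$ call above is standard Scala/Java interop: a Scala object compiles to a class named after the object with a trailing $, and its singleton instance is exposed through the static MODULE$ field. A minimal sketch (the method body here is a placeholder for illustration, not the project's real logic):

    // Scala side: a top-level singleton object
    object analysis {
      def analysis2(data: String): String = data.trim // placeholder body
    }

    // Java side reaches the compiled singleton as:
    //   analysis$.MODULE$.analysis2(getData.getDatas(""));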
@@ -1,22 +0,0 @@
package org.jeecg.yw;

import com.alibaba.fastjson.JSONObject;
import org.jeecg.common.api.vo.Result;
import org.jeecg.yw.ml.LinearRegression;
import org.jeecg.yw.spark.index;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Map;

@RestController
@RequestMapping("/analysis")
public class analysis {
    @RequestMapping("/data")
    public Result test(@RequestBody Map<String, Object> map) {
        JSONObject json = new JSONObject();
        json.put("healthScore", index.getScore(map));
        json.put("foodRecommendations", LinearRegression.getResult(map));
        return Result.ok(json);
    }
}
@@ -0,0 +1,29 @@
package org.jeecg.yw.java;

import com.alibaba.fastjson.JSONObject;
import org.jeecg.common.api.vo.Result;
import org.jeecg.yw.spark.analysis$;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Map;

@RestController
@RequestMapping("/analysis")
public class index {
    @RequestMapping("/data")
    public Result getData(@RequestBody Map<String, Object> map) {
        JSONObject json = new JSONObject();
        json.put("healthScore", analysis$.MODULE$.getScore(map));
        json.put("foodRecommendations", analysis$.MODULE$.getResult(map));
        return Result.ok(json);
    }

    @RequestMapping("/ml")
    public Result getScore(@RequestBody Map<String, Object> map) {
        JSONObject json = new JSONObject();
        json.put("ML", analysis$.MODULE$.getML(map));
        return Result.ok(json);
    }
}
@@ -1,42 +0,0 @@
package org.jeecg.yw.ml;

import com.alibaba.fastjson.JSONObject;

import java.util.Map;
import java.util.Set;

public class LinearRegression {
    // Build the diet-advice string
    public static String getResult(Map<String, Object> map) {
        String[] result = new String[9];
        result[0] = "您的饮食缺少蛋白质,请多吃一些肉类、鱼类、蛋类、奶类等食物";
        result[1] = "您的饮食缺少脂肪,请多吃一些肉类、鱼类、蛋类、奶类等食物";
        result[2] = "您的饮食缺少碳水化合物,请多吃一些米面类食物";
        result[3] = "您的饮食缺少维生素,请多吃一些水果、蔬菜等食物";
        result[4] = "您的饮食缺少矿物质,请多吃一些水果、蔬菜等食物";
        result[5] = "您的饮食缺少纤维,请多吃一些水果、蔬菜等食物";
        result[6] = "您的饮食缺少水,请多喝水";
        result[7] = "您的饮食过多,请适量减少食物摄入";
        result[8] = "您的饮食过油腻,请适量减少食物摄入";
        // Each tip appears with 30% probability
        Set<Integer> set = new java.util.HashSet<>();
        // Draw 9 times; each index is added to the set with 30% probability
        // (so on average 9 × 0.3 = 2.7 tips are selected)
        for (int i = 0; i < 9; i++) {
            int random = (int) (Math.random() * 10);
            if (random < 3) {
                set.add(i);
            }
        }
        // Concatenate the selected tips
        String resultStr = "";
        int j = 0;
        for (int i : set) {
            j++;
            resultStr += j + "." + result[i] + "\n";
        }

        return resultStr;
    }

}
@@ -0,0 +1,118 @@
package org.jeecg.yw.spark

import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

import java.util
import java.util.Map
import scala.collection.convert.ImplicitConversions.{`collection AsScalaIterable`, `map AsScala`}

object analysis {

  import org.apache.spark.sql.Row

  // Machine-learning prediction: score a single request against the saved pipeline
  def getML(map: Map[String, AnyRef]): Double = {
    val spark = SparkSession.builder()
      .appName("Weather Temperature Prediction")
      .master("local[*]")
      .getOrCreate()

    // Load the trained pipeline from HDFS
    val model = PipelineModel.load("hdfs://192.168.192.100:8020/Model")

    // Build a one-row DataFrame from the request map. The pipeline's
    // VectorAssembler requires numeric input columns, so the values are
    // parsed to Double rather than kept as strings.
    val row = Row.fromSeq(map.values.toSeq.map(_.toString.toDouble))
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(row)),
      StructType(map.keys.toSeq.map(fieldName => StructField(fieldName, DoubleType)))
    )

    // Predict, and collect the result before stopping the session
    val predictions = model.transform(df)
    predictions.select("prediction").show()
    val prediction = predictions.select("prediction").collect()(0)(0).toString.toDouble
    spark.stop()
    prediction
  }

  def getScore(map: Map[String, AnyRef]): Double = {
    val age: Int = map.get("age").toString.toInt
    val weight: Int = map.get("weight").toString.toInt
    val sleepTime: Int = map.get("sleepTime").toString.toInt
    val tableDataObject: AnyRef = map.get("tableData")
    var score: Double = 100

    import scala.jdk.CollectionConverters._

    val tableDataList: List[Map[String, AnyRef]] =
      tableDataObject.asInstanceOf[java.util.ArrayList[java.util.Map[String, AnyRef]]].asScala.toList
    // Food types that raise the score; everything else lowers it
    val goodTypes = Seq("绿叶蔬菜", "红橙色蔬菜", "土豆", "其他蔬菜类", "薯类", "水果",
      "大豆制品", "新鲜肉类", "鱼虾或其他海鲜", "蛋类", "奶类")
    // Iterate the submitted food list
    for (item <- tableDataList) {
      val `type`: String = item.get("type").toString
      val num: Int = item.get("num").toString.toInt
      if (goodTypes.contains(`type`)) {
        score += 3.25 * (10 + num) / 10
      } else {
        score -= 10.25 * (10 + num) / 10
      }
    }

    var temp: Double = 0
    // Age only nudges the total: score × (100 + age − 10)%
    score = (score * (100 + age - 10) / 100).toInt
    // Weight: penalize each kg outside the 40–60 kg band by 1%
    if (weight > 60) {
      temp = weight - 60
    } else if (weight < 40) {
      temp = 40 - weight
    } else {
      temp = 0
    }
    score = (score * (100 - temp) / 100).toInt
    // Sleep: penalize each hour outside the 7–12 h band by 1%
    if (sleepTime > 12) {
      temp = sleepTime - 12
    } else if (sleepTime < 7) {
      temp = 7 - sleepTime
    } else {
      temp = 0
    }
    score = (score * (100 - temp) / 100).toInt
    // Cap at 97
    if (score > 97) {
      score = 97
    }
    score
  }

  def getResult(map: util.Map[String, AnyRef]): String = {
    val result: Array[String] = new Array[String](9)
    result(0) = "您的饮食缺少蛋白质,请多吃一些肉类、鱼类、蛋类、奶类等食物"
    result(1) = "您的饮食缺少脂肪,请多吃一些肉类、鱼类、蛋类、奶类等食物"
    result(2) = "您的饮食缺少碳水化合物,请多吃一些米面类食物"
    result(3) = "您的饮食缺少维生素,请多吃一些水果、蔬菜等食物"
    result(4) = "您的饮食缺少矿物质,请多吃一些水果、蔬菜等食物"
    result(5) = "您的饮食缺少纤维,请多吃一些水果、蔬菜等食物"
    result(6) = "您的饮食缺少水,请多喝水"
    result(7) = "您的饮食过多,请适量减少食物摄入"
    result(8) = "您的饮食过油腻,请适量减少食物摄入"
    // Each tip is selected with 30% probability
    val set: util.Set[Integer] = new util.HashSet[Integer]
    for (i <- 0 until 9) {
      val random: Int = (Math.random * 10).toInt
      if (random < 3) set.add(i)
    }
    // Concatenate the selected tips
    var resultStr: String = ""
    var j: Int = 0
    for (i <- set) {
      j += 1
      resultStr += j + "." + result(i) + "\n"
    }
    resultStr
  }

  def getRecord(): Double = {
    // A random integer between 70 and 90
    70 + (Math.random * 21).toInt
  }
}
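A worked example of getScore with a hypothetical payload (age 15, weight 55, sleepTime 8, two good foods with num = 3): the food loop yields 100 + 2 × 3.25 × 1.3 = 108.45, the age factor gives (108.45 × 1.05).toInt = 113, the weight and sleep bands add no penalty, and the cap clamps the result to 97:

    // Hypothetical request payload; field names mirror the controller's map
    val payload = new java.util.HashMap[String, AnyRef]()
    payload.put("age", "15"); payload.put("weight", "55"); payload.put("sleepTime", "8")
    val table = new java.util.ArrayList[java.util.Map[String, AnyRef]]()
    for (t <- Seq("水果", "奶类")) {
      val food = new java.util.HashMap[String, AnyRef]()
      food.put("type", t); food.put("num", "3")
      table.add(food)
    }
    payload.put("tableData", table)
    println(analysis.getScore(payload)) // 97.0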
@@ -1,59 +0,0 @@
package org.jeecg.yw.spark;

import com.alibaba.fastjson.JSONObject;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class index {
    // Compute the health score
    public static double getScore(Map<String, Object> map) {
        // Input keys: age, weight, sleepTime, tableData (a JSON array).
        // Scoring rules: full score is 100. Age (7–18) only nudges the total:
        // total × (100 + age − 10)%, so ages above ten score slightly higher.
        // Weight: each kg below the standard subtracts 1%, each kg above adds 1%;
        // the standard weight is 60 kg.
        // Sleep: each hour less subtracts 1%, each hour more adds 1%;
        // the standard sleep time is 8 hours.
        // Read age, weight and sleep time
        int age = Integer.parseInt(map.get("age").toString());
        int weight = Integer.parseInt(map.get("weight").toString());
        int sleepTime = Integer.parseInt(map.get("sleepTime").toString());
        // Check whether tableData is a List
        Object tableDataObject = map.get("tableData");
        String tableData = map.get("tableData").toString();
        // Convert each tableData entry to a map and iterate the JSON array.
        // Good foods:
        //   绿叶蔬菜 红橙色蔬菜 土豆 其他蔬菜类 薯类
        //   水果 大豆制品 新鲜肉类 鱼虾或其他海鲜 蛋类 奶类
        // Bad foods:
        //   方便面西式快餐 加糖饮料 加糖或盐的零食和甜点 油炸食品 加工肉类
        // 16 fields in total; the structure is [{type=绿叶蔬菜, num=1}, ...] with num in 1–5.
        // Base score is 100; each good food adds N*(1.X) points and each bad food
        // subtracts N*(1.X), where X is num and N = 100/16 = 6.25
        // (e.g. a good food with num = 4 contributes 6.25 × 1.4 = 8.75 points).
        double Score = 100;
        if (tableDataObject instanceof List) {
            List<Map<String, Object>> tableDataList = (List<Map<String, Object>>) tableDataObject;
            // Iterate the list
            for (Map<String, Object> item : tableDataList) {
                // Read type and num
                String type = item.get("type").toString();
                int num = Integer.parseInt(item.get("num").toString());
                if (type.equals("绿叶蔬菜") || type.equals("红橙色蔬菜") || type.equals("土豆") || type.equals("其他蔬菜类") || type.equals("薯类") || type.equals("水果") || type.equals("大豆制品") || type.equals("新鲜肉类") || type.equals("鱼虾或其他海鲜") || type.equals("蛋类") || type.equals("奶类")) {
                    Score += 6.25 * (10 + num) / 10;
                } else {
                    Score -= 9.25 * (10 + num) / 10;
                }
            }
        }
        // Age factor
        Score = (int) (Score * (100 + age - 10) / 100);
        // Weight factor
        Score = (int) (Score * (100 + weight - 60) / 100);
        // Sleep factor
        Score = (int) (Score * (100 + sleepTime - 8) / 100);
        // Cap at 97
        if (Score > 97) {
            Score = 97;
        }
        return Score;
    }
}
|
@ -0,0 +1,26 @@
|
|||
package org.jeecg.yw.spark
|
||||
|
||||
import org.apache.spark.ml.PipelineModel
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
object test {
|
||||
def main(args: Array[String]): Unit = {
|
||||
val spark = SparkSession.builder()
|
||||
.appName("Weather Temperature Prediction")
|
||||
.master("local[*]")
|
||||
.getOrCreate()
|
||||
|
||||
// 加载模型
|
||||
val model = PipelineModel.load("hdfs://192.168.192.100:8020/Model")
|
||||
|
||||
val newData = spark.createDataFrame(Seq(
|
||||
(5.0, 3.0, 2.0, 4.0)
|
||||
)).toDF("绿叶蔬菜", "水果", "大豆制品", "新鲜肉类")
|
||||
|
||||
// 使用模型进行预测
|
||||
val predictions = model.transform(newData)
|
||||
predictions.select("prediction").show()
|
||||
|
||||
spark.stop()
|
||||
}
|
||||
}
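One constraint worth noting: model.transform only works if the incoming DataFrame carries the same column names and numeric types the VectorAssembler was configured with at training time ("绿叶蔬菜", "水果", "大豆制品", "新鲜肉类"); a missing or renamed column makes the assembler throw. A sketch of the failure mode (for illustration; the exact message varies by Spark version):

    // Renaming one expected input column breaks the pipeline:
    // val bad = newData.withColumnRenamed("水果", "fruit")
    // model.transform(bad) // IllegalArgumentException: field "水果" does not exist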
@@ -0,0 +1,63 @@
package org.jeecg.yw.spark

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

object train {
  def main(args: Array[String]): Unit = {
    // Identify to HDFS as user root (a JVM system property, despite the name)
    System.setProperty("HADOOP_USER_NAME", "root")
    // Create a local Spark session (Windows dev machine)
    val spark = SparkSession.builder()
      .appName("Weather Temperature Prediction")
      .master("local[*]")
      .getOrCreate()
    // Load the dataset
    val data = spark.read.option("header", "true").option("inferSchema", "true")
      .csv("C:\\Users\\23972\\Desktop\\Completion-template-cmd\\system\\start\\src\\main\\java\\org\\jeecg\\yw\\spark\\青少年膳食营养数据集.csv")
    // UDF mapping the self-rated health score to a three-class label
    val categorizeHealth = udf((score: Double) => score match {
      case score if score <= 33 => 0
      case score if score <= 66 => 1
      case _ => 2
    })
    // Apply the UDF
    val processedData = data.withColumn("HealthCategory", categorizeHealth(data("健康自评")))
    val assembler = new VectorAssembler()
      .setInputCols(Array("绿叶蔬菜", "水果", "大豆制品", "新鲜肉类"))
      .setOutputCol("features")
    // Standardize the features
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
    // Logistic-regression classifier
    val lr = new LogisticRegression()
      .setFeaturesCol("scaledFeatures")
      .setLabelCol("HealthCategory")
    // Assemble the pipeline
    val pipeline = new Pipeline()
      .setStages(Array(assembler, scaler, lr))
    // Train/test split
    val Array(trainingData, testData) = processedData.randomSplit(Array(0.8, 0.2), seed = 1234L)
    // Fit the model
    val model = pipeline.fit(trainingData)
    // Predict on the held-out set
    val predictions = model.transform(testData)
    // Evaluate accuracy
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("HealthCategory")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val accuracy = evaluator.evaluate(predictions)
    println(s"Accuracy = $accuracy")
    // Save the fitted pipeline to HDFS
    model.write.overwrite().save("hdfs://192.168.192.100:8020/Model")
    // Stop the Spark session
    spark.stop()
  }
}
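Since the whole Pipeline (assembler, scaler, classifier) is saved, a consumer can reload it and decode the numeric predictions; a minimal sketch to run after the split above (the category names are assumptions mirroring the 0/1/2 buckets in categorizeHealth, not labels from the dataset):

    // Reload the fitted pipeline and print readable categories
    val reloaded = PipelineModel.load("hdfs://192.168.192.100:8020/Model")
    val labels = scala.collection.immutable.Map(0.0 -> "low", 1.0 -> "medium", 2.0 -> "high") // assumed names
    reloaded.transform(testData).select("prediction").collect()
      .foreach(r => println(labels(r.getDouble(0))))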
File diff suppressed because it is too large