Machine Learning

parent 6af21ee8e4
commit 5966b77d66
system/start/pom.xml
@@ -10,7 +10,7 @@
     <modelVersion>4.0.0</modelVersion>

     <artifactId>start</artifactId>

     <dependencies>
         <dependency>
             <groupId>org.apache.spark</groupId>
@@ -21,6 +21,26 @@
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-sql_2.13</artifactId>
             <version>3.2.0</version>
+            <exclusions>
+                <exclusion>
+                    <artifactId>janino</artifactId>
+                    <groupId>org.codehaus.janino</groupId>
+                </exclusion>
+                <exclusion>
+                    <artifactId>commons-compiler</artifactId>
+                    <groupId>org.codehaus.janino</groupId>
+                </exclusion>
+            </exclusions>
+        </dependency>
+        <dependency>
+            <artifactId>janino</artifactId>
+            <groupId>org.codehaus.janino</groupId>
+            <version>3.0.8</version>
+        </dependency>
+        <dependency>
+            <artifactId>commons-compiler</artifactId>
+            <groupId>org.codehaus.janino</groupId>
+            <version>3.0.8</version>
+        </dependency>
         </dependency>
         <dependency>
             <groupId>org.apache.spark</groupId>
system/start/src/main/java/org/jeecg/sy/java/Controller.java
@@ -3,7 +3,6 @@ package org.jeecg.sy.java;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
 import org.jeecg.common.api.vo.Result;
-import org.jeecg.sy.scala.WeatherPrediction$;
 import org.jeecg.sy.temp.analysis$;
 import org.jetbrains.annotations.NotNull;
 import org.springframework.data.redis.core.RedisTemplate;
@@ -24,6 +23,7 @@ public class Controller {
     @Resource
     private RedisTemplate<String, Object> redisTemplate;

+    // Web crawler endpoint
     @RequestMapping("/get")
     public Result test1() throws MalformedURLException {
         return Result.ok(analysis$.MODULE$.analysis2(getData.getDatas("")));
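The controller reaches Scala code through the `analysis$.MODULE$` idiom: a Scala `object` compiles to a JVM class named `<object>$` whose single instance is exposed as a static `MODULE$` field. A minimal sketch of that interop, using a hypothetical stand-in for the real `org.jeecg.sy.temp.analysis` object:

// A Scala singleton; scalac emits a class `analysis$` with a
// static field `MODULE$` holding the one instance.
object analysis {
  // Hypothetical stand-in for the real analysis2 used by the controller.
  def analysis2(data: String): String = s"parsed: $data"
}

// Seen from Java, the call in Controller.test1 resolves as:
//   String out = analysis$.MODULE$.analysis2("raw");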
system/start/src/main/java/org/jeecg/yw/analysis.java (deleted)
@@ -1,22 +0,0 @@
-package org.jeecg.yw;
-
-import com.alibaba.fastjson.JSONObject;
-import org.jeecg.common.api.vo.Result;
-import org.jeecg.yw.ml.LinearRegression;
-import org.jeecg.yw.spark.index;
-import org.springframework.web.bind.annotation.RequestBody;
-import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RestController;
-import java.util.Map;
-
-@RestController
-@RequestMapping("/analysis")
-public class analysis {
-    @RequestMapping("/data")
-    public Result test(@RequestBody Map<String, Object> map) {
-        JSONObject json = new JSONObject();
-        json.put("healthScore", index.getScore(map));
-        json.put("foodRecommendations", LinearRegression.getResult(map));
-        return Result.ok(json);
-    }
-}
system/start/src/main/java/org/jeecg/yw/java/index.java (new file)
@@ -0,0 +1,29 @@
+package org.jeecg.yw.java;
+
+import com.alibaba.fastjson.JSONObject;
+import org.jeecg.common.api.vo.Result;
+import org.jeecg.yw.spark.analysis$;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.util.Map;
+
+@RestController
+@RequestMapping("/analysis")
+public class index {
+    @RequestMapping("/data")
+    public Result getData(@RequestBody Map<String, Object> map) {
+        JSONObject json = new JSONObject();
+        json.put("healthScore", analysis$.MODULE$.getScore(map));
+        json.put("foodRecommendations", analysis$.MODULE$.getResult(map));
+        return Result.ok(json);
+    }
+
+    @RequestMapping("/ml")
+    public Result getScore(@RequestBody Map<String, Object> map) {
+        JSONObject json = new JSONObject();
+        json.put("ML", analysis$.MODULE$.getML(map));
+        return Result.ok(json);
+    }
+}
system/start/src/main/java/org/jeecg/yw/ml/LinearRegression.java (deleted)
@@ -1,42 +0,0 @@
-package org.jeecg.yw.ml;
-
-import com.alibaba.fastjson.JSONObject;
-
-import java.util.Map;
-import java.util.Set;
-
-public class LinearRegression {
-    // Build the recommendation text
-    public static String getResult(Map<String, Object> map) {
-        String result[] = new String[9];
-        result[0] = "您的饮食缺少蛋白质,请多吃一些肉类、鱼类、蛋类、奶类等食物";
-        result[1] = "您的饮食缺少脂肪,请多吃一些肉类、鱼类、蛋类、奶类等食物";
-        result[2] = "您的饮食缺少碳水化合物,请多吃一些米面类食物";
-        result[3] = "您的饮食缺少维生素,请多吃一些水果、蔬菜等食物";
-        result[4] = "您的饮食缺少矿物质,请多吃一些水果、蔬菜等食物";
-        result[5] = "您的饮食缺少纤维,请多吃一些水果、蔬菜等食物";
-        result[6] = "您的饮食缺少水,请多喝水";
-        result[7] = "您的饮食过多,请适量减少食物摄入";
-        result[8] = "您的饮食过油腻,请适量减少食物摄入";
-        // Each tip has a 30% chance of being selected
-        Set<Integer> set = new java.util.HashSet<>();
-        // Roll 9 times; each roll adds its index with 30% probability
-        for (int i = 0; i < 9; i++) {
-            int random = (int) (Math.random() * 10);
-            if (random < 3) {
-                set.add(i);
-            }
-        }
-        // Concatenate the selected tips
-        String resultStr = "";
-        int j = 0;
-        for (int i : set) {
-            j++;
-            resultStr += j + "." + result[i] + "\n";
-        }
-
-        return resultStr;
-    }
-
-}
system/start/src/main/java/org/jeecg/yw/spark/analysis.scala (new file)
@@ -0,0 +1,118 @@
+package org.jeecg.yw.spark
+
+import org.apache.spark.ml.PipelineModel
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+
+import java.util
+import java.util.{HashSet, Map, Set}
+import scala.collection.convert.ImplicitConversions.{`collection AsScalaIterable`, `map AsScala`}
+
+object analysis {
+
+  import org.apache.spark.sql.Row
+
+  // Machine-learning prediction
+  def getML(map: Map[String, AnyRef]): Double = {
+    val spark = SparkSession.builder()
+      .appName("Weather Temperature Prediction")
+      .master("local[*]")
+      .getOrCreate()
+
+    // Load the trained pipeline
+    val model = PipelineModel.load("hdfs://192.168.192.100:8020/Model")
+
+    // Build a one-row DataFrame from the request map; the pipeline's
+    // VectorAssembler expects numeric columns, so cast every value to Double
+    val row = Row.fromSeq(map.values.toSeq.map(_.toString.toDouble))
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(Seq(row)),
+      StructType(map.keys.toSeq.map(fieldName => StructField(fieldName, DoubleType)))
+    )
+
+    // Run the model and collect the prediction before stopping the session
+    val predictions = model.transform(df)
+    predictions.select("prediction").show()
+    val prediction = predictions.select("prediction").collect()(0)(0).toString.toDouble
+    spark.stop()
+    prediction
+  }
+
+  def getScore(map: Map[String, AnyRef]): Double = {
+    val age: Int = map.get("age").toString.toInt
+    val weight: Int = map.get("weight").toString.toInt
+    val sleepTime: Int = map.get("sleepTime").toString.toInt
+    val tableDataObject: AnyRef = map.get("tableData")
+    var score: Double = 100
+
+    import scala.jdk.CollectionConverters._
+
+    val tableDataList: List[Map[String, AnyRef]] =
+      tableDataObject.asInstanceOf[java.util.ArrayList[java.util.Map[String, AnyRef]]].asScala.toList
+    // Walk the food entries: healthy food types raise the score, the rest lower it
+    for (item <- tableDataList) {
+      val `type`: String = item.get("type").toString
+      val num: Int = item.get("num").toString.toInt
+      if (`type` == "绿叶蔬菜" || `type` == "红橙色蔬菜" || `type` == "土豆" || `type` == "其他蔬菜类" || `type` == "薯类" || `type` == "水果" || `type` == "大豆制品" || `type` == "新鲜肉类" || `type` == "鱼虾或其他海鲜" || `type` == "蛋类" || `type` == "奶类") {
+        score += 3.25 * (10 + num) / 10
+      } else {
+        score -= 10.25 * (10 + num) / 10
+      }
+    }
+
+    var temp: Double = 0
+    // Age factor
+    score = (score * (100 + age - 10) / 100).toInt
+    // Weight factor: temp holds the deviation from the 40-60 kg band
+    if (weight > 60) {
+      temp = weight - 60
+    } else if (weight < 40) {
+      temp = 40 - weight
+    } else {
+      temp = 0
+    }
+    score = (score * (100 - temp) / 100).toInt
+    // Sleep factor: temp holds the deviation from the 7-12 hour band
+    if (sleepTime > 12) {
+      temp = sleepTime - 12
+    } else if (sleepTime < 7) {
+      temp = 7 - sleepTime
+    } else {
+      temp = 0
+    }
+    score = (score * (100 - temp) / 100).toInt
+    // Cap the score at 97
+    if (score > 97) {
+      score = 97
+    }
+    score
+  }
+
+  def getResult(map: util.Map[String, AnyRef]): String = {
+    val result: Array[String] = new Array[String](9)
+    result(0) = "您的饮食缺少蛋白质,请多吃一些肉类、鱼类、蛋类、奶类等食物"
+    result(1) = "您的饮食缺少脂肪,请多吃一些肉类、鱼类、蛋类、奶类等食物"
+    result(2) = "您的饮食缺少碳水化合物,请多吃一些米面类食物"
+    result(3) = "您的饮食缺少维生素,请多吃一些水果、蔬菜等食物"
+    result(4) = "您的饮食缺少矿物质,请多吃一些水果、蔬菜等食物"
+    result(5) = "您的饮食缺少纤维,请多吃一些水果、蔬菜等食物"
+    result(6) = "您的饮食缺少水,请多喝水"
+    result(7) = "您的饮食过多,请适量减少食物摄入"
+    result(8) = "您的饮食过油腻,请适量减少食物摄入"
+    // Each tip is selected with 30% probability
+    val set: util.Set[Integer] = new util.HashSet[Integer]
+    for (i <- 0 until 9) {
+      val random: Int = (Math.random * 10).toInt
+      if (random < 3) set.add(i)
+    }
+    // Concatenate the selected tips
+    var resultStr: String = ""
+    var j: Int = 0
+    for (i <- set) {
+      j += 1
+      resultStr += j + "." + result(i) + "\n"
+    }
+    resultStr
+  }
+
+  def getRecord(): Double = {
+    // A random number between 70 and 90
+    val random: Double = 70 + (Math.random * 21).toInt
+    random
+  }
+}
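For reference, a minimal sketch of calling `getScore` the way the `/analysis/data` endpoint does, with hypothetical request values (the keys and the food-type string mirror what the frontend sends):

import java.util

// Hypothetical payload: age 15, weight 50 kg, 8 h sleep, one food entry.
val payload = new util.HashMap[String, AnyRef]()
payload.put("age", "15")
payload.put("weight", "50")
payload.put("sleepTime", "8")

val tableData = new util.ArrayList[util.Map[String, AnyRef]]()
val entry = new util.HashMap[String, AnyRef]()
entry.put("type", "绿叶蔬菜") // a healthy type: +3.25 * (10 + num) / 10
entry.put("num", "3")
tableData.add(entry)
payload.put("tableData", tableData)

// 100 + 3.25*1.3 = 104.225 -> *105% (age) -> 109; weight and sleep are
// inside their bands, so no deduction; the cap then applies.
println(analysis.getScore(payload)) // 97.0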
system/start/src/main/java/org/jeecg/yw/spark/index.java (deleted)
@@ -1,59 +0,0 @@
-package org.jeecg.yw.spark;
-
-import com.alibaba.fastjson.JSONObject;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-public class index {
-    // Compute the health score
-    public static double getScore(Map<String, Object> map) {
-        // Inputs: age, weight, sleepTime, tableData (a JSON array)
-        // Scoring rules: full score is 100. Age (7-18) only floats the total:
-        // total * (100 + age - 10)%, so ages above ten score somewhat higher.
-        // Weight: -1% per kg below, +1% per kg above; the standard weight is
-        // Sleep: -1% per hour less, +1% per hour more; the standard is 8 hours
-        // Read age, weight and sleep time
-        int age = Integer.parseInt(map.get("age").toString());
-        int weight = Integer.parseInt(map.get("weight").toString());
-        int sleepTime = Integer.parseInt(map.get("sleepTime").toString());
-        // Check whether tableData is a List
-        Object tableDataObject = map.get("tableData");
-        String tableData = map.get("tableData").toString();
-        // Convert tableData to a map and iterate over the JSON array.
-        // Good types:
-        //   绿叶蔬菜 红橙色蔬菜 土豆 其他蔬菜类 薯类
-        //   水果 大豆制品 新鲜肉类 鱼虾或其他海鲜 蛋类 奶类
-        // Bad types:
-        //   方便面西式快餐 加糖饮料 加糖或盐的零食和甜点 油炸食品 加工肉类
-        // 16 fields in total, shaped like [{type=绿叶蔬菜, num=1}, ...], num is 1-5.
-        // Base score is 100; each good food adds N*(1.X) points and each bad
-        // food subtracts N*(1.X), where X is num and N is 100/16 = 6.25.
-        double Score = 100;
-        if (tableDataObject instanceof List) {
-            List<Map<String, Object>> tableDataList = (List<Map<String, Object>>) tableDataObject;
-            // Walk the list
-            for (Map<String, Object> item : tableDataList) {
-                // Read type and num
-                String type = item.get("type").toString();
-                int num = Integer.parseInt(item.get("num").toString());
-                if (type.equals("绿叶蔬菜") || type.equals("红橙色蔬菜") || type.equals("土豆") || type.equals("其他蔬菜类") || type.equals("薯类") || type.equals("水果") || type.equals("大豆制品") || type.equals("新鲜肉类") || type.equals("鱼虾或其他海鲜") || type.equals("蛋类") || type.equals("奶类")) {
-                    Score += 6.25 * (10 + num) / 10;
-                } else {
-                    Score -= 9.25 * (10 + num) / 10;
-                }
-            }
-        }
-        // Age adjustment
-        Score = (int) (Score * (100 + age - 10) / 100);
-        // Weight adjustment
-        Score = (int) (Score * (100 + weight - 60) / 100);
-        // Sleep adjustment
-        Score = (int) (Score * (100 + sleepTime - 8) / 100);
-        // Cap at 97
-        if (Score > 97) {
-            Score = 97;
-        }
-        return Score;
-    }
-}
system/start/src/main/java/org/jeecg/yw/spark/test.scala (new file)
@@ -0,0 +1,26 @@
+package org.jeecg.yw.spark
+
+import org.apache.spark.ml.PipelineModel
+import org.apache.spark.sql.SparkSession
+
+object test {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession.builder()
+      .appName("Weather Temperature Prediction")
+      .master("local[*]")
+      .getOrCreate()
+
+    // Load the trained pipeline from HDFS
+    val model = PipelineModel.load("hdfs://192.168.192.100:8020/Model")
+
+    // A single sample with the four feature columns used in training
+    val newData = spark.createDataFrame(Seq(
+      (5.0, 3.0, 2.0, 4.0)
+    )).toDF("绿叶蔬菜", "水果", "大豆制品", "新鲜肉类")
+
+    // Run the model
+    val predictions = model.transform(newData)
+    predictions.select("prediction").show()
+
+    spark.stop()
+  }
+}
system/start/src/main/java/org/jeecg/yw/spark/train.scala (new file)
@@ -0,0 +1,63 @@
+package org.jeecg.yw.spark
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.ml.feature.{VectorAssembler, StandardScaler}
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+
+object train {
+  def main(args: Array[String]): Unit = {
+    // Set HADOOP_USER_NAME so HDFS writes run as root
+    System.setProperty("HADOOP_USER_NAME", "root")
+    // Create the Spark session (local run on Windows)
+    val spark = SparkSession.builder()
+      .appName("Weather Temperature Prediction")
+      .master("local[*]")
+      .getOrCreate()
+    // Load the dataset
+    val data = spark.read.option("header", "true").option("inferSchema", "true")
+      .csv("C:\\Users\\23972\\Desktop\\Completion-template-cmd\\system\\start\\src\\main\\java\\org\\jeecg\\yw\\spark\\青少年膳食营养数据集.csv")
+    // UDF that buckets the health self-rating into three class labels
+    val categorizeHealth = udf((score: Double) => score match {
+      case score if score <= 33 => 0
+      case score if score <= 66 => 1
+      case _ => 2
+    })
+    // Apply the UDF
+    val processedData = data.withColumn("HealthCategory", categorizeHealth(data("健康自评")))
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("绿叶蔬菜", "水果", "大豆制品", "新鲜肉类"))
+      .setOutputCol("features")
+    // Standardize the features
+    val scaler = new StandardScaler()
+      .setInputCol("features")
+      .setOutputCol("scaledFeatures")
+    // Initialize the logistic regression model
+    val lr = new LogisticRegression()
+      .setFeaturesCol("scaledFeatures")
+      .setLabelCol("HealthCategory")
+    // Assemble the pipeline
+    val pipeline = new Pipeline()
+      .setStages(Array(assembler, scaler, lr))
+    // Split the dataset
+    val Array(trainingData, testData) = processedData.randomSplit(Array(0.8, 0.2), seed = 1234L)
+    // Train the model
+    val model = pipeline.fit(trainingData)
+    // Predict
+    val predictions = model.transform(testData)
+    // Evaluate the model
+    val evaluator = new MulticlassClassificationEvaluator()
+      .setLabelCol("HealthCategory")
+      .setPredictionCol("prediction")
+      .setMetricName("accuracy")
+
+    val accuracy = evaluator.evaluate(predictions)
+    println(s"Accuracy = $accuracy")
+    // Save the model to HDFS
+    model.write.overwrite().save("hdfs://192.168.192.100:8020/Model")
+    // Stop the Spark session
+    spark.stop()
+  }
+}
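Since `train` only prints accuracy, a small follow-up sketch that could sit at the end of `train.main` (same evaluator API, reusing the `predictions` DataFrame) to also report weighted F1, which is more informative if the three health buckets are imbalanced. This is a hypothetical addition, not part of the commit:

// Reuse the fitted pipeline's test-set predictions from train.main.
val f1Evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("HealthCategory")
  .setPredictionCol("prediction")
  .setMetricName("f1") // weighted F1 over the 0/1/2 buckets
println(s"F1 = ${f1Evaluator.evaluate(predictions)}")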
File diff suppressed because it is too large