文章目录
一.在百度智能云申请获取APIKey和SecretKey
baidu:
  ai:
    APIKey: xxxxxxx
    SecretKey: xxxxxxxx
二.基础工具类
1.获取访问token工具类
//通过APIKey和SecretKey获取
@Slf4j
public class BaiduAiAuthUtil {
// 参考:https://ai.baidu.com/ai-doc/REFERENCE/Ck3dwjhhu
/** * 获取API访问token * 该token有一定的有效期,需要自行管理,当失效时需重新获取. * * @param ak - 百度云官网获取的 API Key * @param sk - 百度云官网获取的 Secret Key * @return assess_token 示例: * { * "access_token": "24.460da4889caad24cccdb1fea17221975.2592000.1491995545.282335-1234567", * "expires_in": 2592000 * } */
public static String getAuth(String ak, String sk) {
// 获取token地址
String authHost = "https://aip.baidubce.com/oauth/2.0/token?";
String getAccessTokenUrl = authHost
// 1. grant_type为固定参数
+ "grant_type=client_credentials"
// 2. 官网获取的 API Key
+ "&client_id=" + ak
// 3. 官网获取的 Secret Key
+ "&client_secret=" + sk;
try {
URL realUrl = new URL(getAccessTokenUrl);
// 打开和URL之间的连接
HttpURLConnection connection = (HttpURLConnection) realUrl.openConnection();
connection.setRequestMethod("GET");
connection.connect();
// 获取所有响应头字段
Map<String, List<String>> map = connection.getHeaderFields();
// 遍历所有的响应头字段
for (String key : map.keySet()) {
System.err.println(key + "--->" + map.get(key));
}
// 定义 BufferedReader输入流来读取URL的响应
BufferedReader in = new BufferedReader(new InputStreamReader(connection.getInputStream()));
String result = "";
String line;
while ((line = in.readLine()) != null) {
result += line;
}
/** * 返回结果示例 */
System.err.println("result:" + result);
JSONObject jsonObject = new JSONObject(result);
String access_token = jsonObject.getString("access_token");
return access_token;
} catch (Exception e) {
log.error("获取token失败!错误原因:{}", e.getMessage());
e.printStackTrace(System.err);
}
return null;
}
}
将获取到的access_token缓存到Redis中
/**
 * Returns the Baidu AI access token, preferring the value cached in Redis and falling back
 * to {@code BaiduAiAuthUtil.getAuth} when nothing usable is cached.
 *
 * @return the cached or freshly fetched token; whatever {@code getAuth} returned (possibly
 *         {@code null} or blank) when the fetch produced no usable token
 */
public String getBaiDuAuth() {
    String cached = (String) redisTemplate.opsForValue().get(BAIDU_AI_IMAGE_ASSESS_TOKEN_KEY);
    if (!StringUtils.isBlank(cached)) {
        return cached;
    }

    String fresh = BaiduAiAuthUtil.getAuth(APIKey, SecretKey);
    if (StringUtils.isBlank(fresh)) {
        // Nothing worth caching; hand the result back unchanged.
        return fresh;
    }

    // The provider reports a 30-day lifetime; cache for one hour less to allow for delays.
    redisTemplate.opsForValue().set(BAIDU_AI_IMAGE_ASSESS_TOKEN_KEY, fresh, 2592000 - 3600, TimeUnit.SECONDS);
    return fresh;
}
2.Base64转换工具类
/**
 * Base64 helper.
 */
public class Base64Util {

    /** Kept public so existing code that instantiates this class keeps compiling. */
    public Base64Util() {
    }

    /**
     * Encodes the given bytes as standard Base64 (alphabet {@code A-Z a-z 0-9 + /}),
     * padded with {@code '='} to a multiple of four characters and without line breaks.
     *
     * <p>Delegates to the JDK encoder instead of a hand-written bit-shuffling loop; the
     * class is referenced by its fully-qualified name so no extra import is required.
     *
     * @param from the bytes to encode; must not be {@code null}
     * @return the Base64 text, empty for an empty input
     * @throws NullPointerException if {@code from} is {@code null}
     */
    public static String encode(byte[] from) {
        return java.util.Base64.getEncoder().encodeToString(from);
    }
}
3.文件转换工具类
/**
 * File reading helpers.
 */
public class FileUtil {

    /** Largest file (1 GiB) that {@link #readFileAsString(String)} is willing to load. */
    private static final long MAX_STRING_FILE_BYTES = 1024L * 1024 * 1024;

    /**
     * Reads a whole file and returns its content as a string, decoded with the platform
     * default charset.
     *
     * @param filePath path of the file to read
     * @return the decoded file content
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException if the file is larger than 1 GiB or cannot be read
     */
    public static String readFileAsString(String filePath) throws IOException {
        File file = new File(filePath);
        if (!file.exists()) {
            throw new FileNotFoundException(filePath);
        }
        if (file.length() > MAX_STRING_FILE_BYTES) {
            throw new IOException("File is too large");
        }
        // Decode once over the complete content: decoding every read buffer on its own
        // corrupts multi-byte characters that straddle a buffer boundary.
        return new String(readFileByBytes(filePath));
    }

    /**
     * Reads a whole file into a byte array.
     *
     * @param filePath path of the file to read
     * @return the file content
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException if reading fails
     */
    public static byte[] readFileByBytes(String filePath) throws IOException {
        File file = new File(filePath);
        if (!file.exists()) {
            throw new FileNotFoundException(filePath);
        }
        ByteArrayOutputStream bos = new ByteArrayOutputStream((int) file.length());
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            copy(in, bos);
        }
        return bos.toByteArray();
    }

    /**
     * Downloads the resource behind an HTTP(S) URL into a temporary file.
     *
     * <p>Every call writes to its own uniquely named file in {@code java.io.tmpdir}, so
     * concurrent downloads cannot overwrite each other. The file is registered for deletion
     * at JVM exit; callers should delete it themselves as soon as they are done with it.
     *
     * @param url the URL to download
     * @return the temporary file holding the downloaded content
     * @throws Exception if the URL is invalid or the download fails
     */
    public static File urlToFile(String url) throws Exception {
        HttpURLConnection httpUrl = (HttpURLConnection) new URL(url).openConnection();
        httpUrl.connect();
        try (InputStream ins = httpUrl.getInputStream()) {
            File file = File.createTempFile("xie", ".tmp");
            file.deleteOnExit();
            try (OutputStream os = new FileOutputStream(file)) {
                copy(ins, os);
            } catch (IOException e) {
                // Do not leave a partially written download behind.
                file.delete();
                throw e;
            }
            return file;
        } finally {
            httpUrl.disconnect();
        }
    }

    /**
     * Reads the complete content of a file into a byte array.
     *
     * <p>A read failure is reported as an exception rather than being swallowed, because a
     * swallowed failure would hand the caller a buffer that is silently incomplete.
     *
     * @param file the file to read
     * @return the file content
     * @throws java.io.UncheckedIOException if the file cannot be opened or read
     */
    public static byte[] fileToByte(File file) {
        try (InputStream in = new FileInputStream(file)) {
            ByteArrayOutputStream bos = new ByteArrayOutputStream((int) file.length());
            // A single read() call may return fewer bytes than requested, so copy in a loop.
            copy(in, bos);
            return bos.toByteArray();
        } catch (IOException e) {
            throw new java.io.UncheckedIOException("Failed to read file: " + file, e);
        }
    }

    /** Copies everything from {@code in} to {@code out}; closes neither stream. */
    private static void copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[8192];
        int read;
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
    }
}
4.Http工具类
/**
 * Minimal HTTP POST helper.
 */
public class HttpUtil {

    /**
     * Posts form-encoded parameters to {@code requestUrl}, authenticating with the token.
     *
     * @param requestUrl  endpoint URL without the {@code access_token} query parameter
     * @param accessToken token appended as the {@code access_token} query parameter
     * @param params      request body
     * @return the response body with line terminators removed
     * @throws Exception if the request fails
     */
    public static String post(String requestUrl, String accessToken, String params)
            throws Exception {
        String contentType = "application/x-www-form-urlencoded";
        return HttpUtil.post(requestUrl, accessToken, contentType, params);
    }

    /**
     * Same as {@link #post(String, String, String)} with an explicit content type. The body
     * is encoded as UTF-8, except that URLs containing {@code "nlp"} use GBK.
     *
     * @param requestUrl  endpoint URL without the {@code access_token} query parameter
     * @param accessToken token appended as the {@code access_token} query parameter
     * @param contentType value of the {@code Content-Type} request header
     * @param params      request body
     * @return the response body with line terminators removed
     * @throws Exception if the request fails
     */
    public static String post(String requestUrl, String accessToken, String contentType, String params)
            throws Exception {
        String encoding = "UTF-8";
        if (requestUrl.contains("nlp")) {
            encoding = "GBK";
        }
        return HttpUtil.post(requestUrl, accessToken, contentType, params, encoding);
    }

    /**
     * Same as {@link #post(String, String, String, String)} with an explicit charset.
     *
     * @param requestUrl  endpoint URL without the {@code access_token} query parameter
     * @param accessToken token appended as the {@code access_token} query parameter
     * @param contentType value of the {@code Content-Type} request header
     * @param params      request body
     * @param encoding    charset name used for the request body and for decoding the response
     * @return the response body with line terminators removed
     * @throws Exception if the request fails
     */
    public static String post(String requestUrl, String accessToken, String contentType, String params, String encoding)
            throws Exception {
        String url = requestUrl + "?access_token=" + accessToken;
        return HttpUtil.postGeneralUrl(url, contentType, params, encoding);
    }

    /**
     * Sends {@code params} as the body of a POST request to {@code generalUrl} and returns
     * the response body.
     *
     * @param generalUrl  complete request URL, including any query string
     * @param contentType value of the {@code Content-Type} request header
     * @param params      request body
     * @param encoding    charset name used for the request body and for decoding the response
     * @return the response body with line terminators removed
     * @throws Exception if the URL is malformed, the charset is unsupported or the request fails
     */
    public static String postGeneralUrl(String generalUrl, String contentType, String params, String encoding)
            throws Exception {
        HttpURLConnection connection = (HttpURLConnection) new URL(generalUrl).openConnection();
        connection.setRequestMethod("POST");
        // Common request properties
        connection.setRequestProperty("Content-Type", contentType);
        connection.setRequestProperty("Connection", "Keep-Alive");
        connection.setUseCaches(false);
        connection.setDoOutput(true);
        connection.setDoInput(true);

        // Obtaining the output stream opens the connection, so no explicit connect() is needed.
        // try-with-resources closes the stream even when writing fails.
        try (DataOutputStream out = new DataOutputStream(connection.getOutputStream())) {
            out.write(params.getBytes(encoding));
            out.flush();
        }

        StringBuilder result = new StringBuilder();
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(connection.getInputStream(), encoding))) {
            String line;
            while ((line = in.readLine()) != null) {
                result.append(line);
            }
        }
        return result.toString();
    }
}
三.图片识别
图片识别工具类
@Slf4j
public class BaiduAiImageUtil {
// 参考:https://ai.baidu.com/ai-doc/IMAGERECOGNITION/Xk3bcxe21
/** * 重要提示代码中所需工具类 * FileUtil,Base64Util,HttpUtil,GsonUtils请从 * https://ai.baidu.com/file/658A35ABAB2D404FBF903F64D47C1F72 * https://ai.baidu.com/file/C8D81F3301E24D2892968F09AE1AD6E2 * https://ai.baidu.com/file/544D677F5D4E4F17B4122FBD60DB82B3 * https://ai.baidu.com/file/470B3ACCA3FE43788B5A963BF0B625F3 * 下载 */
//图片为文件格式
public static String advancedGeneralByImage(MultipartFile file, String accessToken) {
if (ObjectUtil.isEmpty(file) || StringUtils.isAnyBlank(accessToken)) {
return null;
}
// 请求url
String url = "https://aip.baidubce.com/rest/2.0/image-classify/v2/advanced_general";
try {
byte[] imgData = file.getBytes();
String imgStr = Base64Util.encode(imgData);
String imgParam = URLEncoder.encode(imgStr, "UTF-8");
String param = "image=" + imgParam;
String result = HttpUtil.post(url, accessToken, param);
log.info("上传图片:{},AI图像识别结果:{}", file, result);
return result;
} catch (Exception e) {
log.error("AI图像识别错误!错误原因:{}", e.getMessage());
e.printStackTrace(System.err);
}
return null;
}
//图片为URL格式
public static String advancedGeneralByUrl(String remoteUrl, String accessToken) {
if (StringUtils.isAnyBlank(remoteUrl, accessToken)) {
return null;
}
// 请求url
String url = "https://aip.baidubce.com/rest/2.0/image-classify/v2/advanced_general";
try {
String urlParam = URLEncoder.encode(remoteUrl, "UTF-8");
String param = "url=" + urlParam;
String result = HttpUtil.post(url, accessToken, param);
log.info("oss图片链接:{},AI图像识别结果:{}", remoteUrl, result);
return result;
} catch (Exception e) {
log.error("AI图像识别错误!错误原因:{}", e.getMessage());
e.printStackTrace(System.err);
}
return null;
}
}
接口
/**
 * AI garbage classification by image recognition.
 *
 * @param file the uploaded image file
 * @return the outcome of {@code iAiClassifyRecordService.aiImageV2(file)} wrapped by
 *         {@code success(...)}
 */
@PostMapping("/image")
public Result<List<AiRespDTO>> aiImage(@RequestParam("file") MultipartFile file) {
    return success(iAiClassifyRecordService.aiImageV2(file));
}
实现类
/**
 * Validates the uploaded image, sends it to {@code BaiduAiImageUtil.advancedGeneralByImage}
 * using the token from {@code getBaiDuAuth()} and parses the response into an
 * {@code AiImageResponse}.
 *
 * NOTE(review): as shown, the method ends without a return statement; the code that turns
 * the recognition result into the returned list appears to have been omitted.
 *
 * @param file the uploaded image file
 */
@SneakyThrows
@Override
public List<ComClassificationRespDTO> aiImageV2(MultipartFile file) {
// Validate the file extension; checkExt throws for unsupported formats
String ext = checkExt(file);
// NOTE(review): bytes/encode are not used in the lines shown; advancedGeneralByImage
// reads and encodes the file itself
byte[] bytes = file.getBytes();
String encode = Base64.getEncoder().encodeToString(bytes);
String auth = getBaiDuAuth();
String aiResult = BaiduAiImageUtil.advancedGeneralByImage(file, auth);
// NOTE(review): advancedGeneralByImage returns null on failure; confirm how parseObject
// handles a null input before aiImageResponse is dereferenced below
AiImageResponse aiImageResponse=JSONObject.parseObject(aiResult, AiImageResponse.class);
// Extract the recognized keywords
List<AiImageResult> result = aiImageResponse.getResult();
}
百度图像识别接口返回数据格式
/**
 * Response body of the Baidu image recognition endpoint.
 */
@Data
public class AiImageResponse {
// Reference: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/Xk3bcxe21
/** Unique log id, used for troubleshooting. */
private String log_id;
/** Number of results, i.e. the number of elements in the result array; at most 5 results are returned. */
private String result_num;
/** Array of label results. */
private List<AiImageResult> result;
}
/**
 * One label entry of an image recognition response.
 */
@Data
public class AiImageResult {
// Reference: https://ai.baidu.com/ai-doc/IMAGERECOGNITION/Xk3bcxe21
/** Confidence, 0-1. */
private float score;
/** Parent (upper-level) label of the result; some tags such as coins, anime, tobacco and alcohol have no parent label. */
private String root;
/** Name of the object or scene in the image. */
private String keyword;
}
检查图片格式
/**
 * Validates that the uploaded file has a supported image extension.
 *
 * <p>The comparison ignores case, so a name such as {@code photo.JPG} is accepted as well.
 *
 * @param file the uploaded file whose original file name is inspected
 * @return the extension exactly as it appears in the original file name
 * @throws ScException if the extension is not one of jpg, png, bmp or jpeg
 */
public String checkExt(MultipartFile file) {
    String ext = FilenameUtils.getExtension(file.getOriginalFilename());
    boolean supported = "jpg".equalsIgnoreCase(ext)
            || "png".equalsIgnoreCase(ext)
            || "bmp".equalsIgnoreCase(ext)
            || "jpeg".equalsIgnoreCase(ext);
    if (!supported) {
        throw new ScException("只支持'jpg','png','bmp','jpeg'格式的图像文件");
    }
    return ext;
}
四.语音识别
语音识别工具类
@Slf4j
public class BaiduAiVoiceUtil {
// 参考:https://ai.baidu.com/ai-doc/SPEECH/jkhq0ohzz
static final okhttp3.OkHttpClient HTTP_CLIENT = new okhttp3.OkHttpClient().newBuilder().build();
/** * 语音识别 */
public static String advancedGeneral(String remoteUrl, String userId, String accessToken, String ext) {
if (StringUtils.isAnyBlank(remoteUrl, userId, accessToken, ext)) {
return null;
}
// 请求url
String url = "https://vop.baidu.com/pro_api";
try {
File file = FileUtil.urlToFile(remoteUrl);
int length = (int) file.length();
byte[] voiceData = FileUtil.fileToByte(file);
String voiceStr = Base64Util.encode(voiceData);
AiVoiceRequest aiVoiceRequest = new AiVoiceRequest();
aiVoiceRequest.setFormat(ext);
aiVoiceRequest.setCuid(userId);
aiVoiceRequest.setToken(accessToken);
aiVoiceRequest.setSpeech(voiceStr);
aiVoiceRequest.setLen(length);
MediaType mediaType = MediaType.parse("application/json");
// speech 可以通过 getFileContentAsBase64("C:\fakepath\Skype_Notification.m4a") 方法获取
okhttp3.RequestBody body = okhttp3.RequestBody.create(mediaType, new JSONObject(aiVoiceRequest).toString());
Request request = new Request.Builder()
.url(url)
.method("POST", body)
.addHeader("Content-Type", "application/json")
.addHeader("Accept", "application/json")
.build();
okhttp3.Response response = HTTP_CLIENT.newCall(request).execute();
String result = response.body().string();
log.info("oss音频链接:{},AI语音识别结果:{}", remoteUrl, result);
return result;
} catch (Exception e) {
log.error("AI语音识别错误!错误原因:{}", e.getMessage());
e.printStackTrace(System.err);
}
return null;
}
/** * 获取文件base64编码 * * @param path 文件路径 * @return base64编码信息,不带文件头 * @throws IOException IO异常 */
static String getFileContentAsBase64(String path) throws IOException {
byte[] b = Files.readAllBytes(Paths.get(path));
return Base64.getEncoder().encodeToString(b);
}
}
接口
/**
 * AI garbage classification by speech recognition.
 *
 * @param bean the validated request payload
 * @return whatever {@code iAiClassifyRecordService.aiVoiceV2(bean)} returns
 */
@PostMapping(value = "/voice", produces = MediaType.APPLICATION_JSON_UTF8_VALUE)
public List<ComClassificationRespDTO> aiVoice(@RequestBody @Validated AiClassifyRecordQueryDTO bean) {
    List<ComClassificationRespDTO> classifications = iAiClassifyRecordService.aiVoiceV2(bean);
    return classifications;
}
实现类
/**
 * Uploads the audio file, submits it to {@code BaiduAiVoiceUtil.advancedGeneral} and parses
 * the response into an {@code AiVoiceResponse}.
 *
 * @param file the uploaded audio file
 * @param type not used in this method body
 * @return always {@code null} in the code shown
 * @throws ScException if the extension is unsupported or the upload yields no URL
 */
public ComClassificationRespDTO aiVoice(MultipartFile file, Integer type) {
    String userId = Context.getUserID;
    // Validate and obtain the file extension
    String extension = this.getAndCheckExt(file);
    // Upload the file to obtain a remote link
    String audioUrl = this.uploadFile(file);
    if (StringUtils.isBlank(audioUrl)) {
        throw new ScException(RespCode.UPLOAD_FILE_FAILED);
    }
    String token = getBaiDuAuth();
    String rawResponse = BaiduAiVoiceUtil.advancedGeneral(audioUrl, userId, token, extension);
    AiVoiceResponse parsed = JSONObject.parseObject(rawResponse, AiVoiceResponse.class);
    List<String> candidates = parsed.getResult();
    return null;
}
百度语音识别接口返回数据格式
/**
 * Response body of the Baidu speech recognition endpoint.
 */
@Data
public class AiVoiceResponse {
// Reference: https://ai.baidu.com/ai-doc/SPEECH/Ykhq0pqq8
/** Error code. */
private String err_no;
/** Description of the error code. */
private String err_msg;
/** Unique identifier of the speech data, generated by the system; provide the sn when giving feedback or debugging. */
private String sn;
/** Array of recognition results; one best candidate is returned. UTF-8 encoded. */
private List<String> result;
}
百度语音识别接口请求入参数据格式
/**
 * Request body of the Baidu speech recognition endpoint.
 */
@Data
public class AiVoiceRequest {
// Reference: https://ai.baidu.com/ai-doc/SPEECH/Ikhq0parc
/** Format of the audio file: pcm/wav/amr/m4a. Case-insensitive. pcm is recommended. */
private String format;
/** Sample rate; fixed value 16000. */
private Integer rate = 16000;
/** Number of channels; only mono is supported, use the fixed value 1. */
private Integer channel = 1;
/** Unique user identifier, used to tell users apart and compute UV. A machine MAC address or IMEI is suggested; at most 60 characters. */
private String cuid;
/** Developer access_token obtained from the open platform. */
private String token;
/** 80001 (speed-edition input-method model). */
private Integer dev_pid = 80001;
/** Binary audio data of the local audio file, Base64-encoded. Used together with the len parameter. */
private String speech;
/** Size of the local audio file in bytes. */
private Integer len;
}
检查语音格式
/**
 * Validates that the uploaded file has a supported audio extension.
 *
 * <p>The comparison ignores case, so a name such as {@code voice.WAV} is accepted as well.
 *
 * @param file the uploaded file whose original file name is inspected
 * @return the extension exactly as it appears in the original file name
 * @throws ScException if the extension is not one of pcm, wav, amr or m4a
 */
public String getAndCheckExt(MultipartFile file) {
    String ext = FilenameUtils.getExtension(file.getOriginalFilename());
    boolean supported = "pcm".equalsIgnoreCase(ext)
            || "wav".equalsIgnoreCase(ext)
            || "amr".equalsIgnoreCase(ext)
            || "m4a".equalsIgnoreCase(ext);
    if (!supported) {
        throw new ScException("只支持'pcm','wav','amr','m4a'格式的音频文件");
    }
    return ext;
}
文章评论